File size: 4,621 Bytes
4be5ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
560cef6
 
4be5ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""Text encoder preflight health check for gated Hugging Face access and local cache paths."""

from __future__ import annotations

import argparse
import json
import os

from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig


# Supported text encoder presets, keyed by the name accepted on the command
# line. Each preset lists the Hugging Face repo ids used for the base model
# and for the PEFT weights.
TEXT_ENCODER_PRESETS = {
    "llm2vec": dict(
        base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
    ),
}


def _get_hf_token() -> str | None:
    """Return the Hugging Face token from the environment, if any.

    The variables are consulted in priority order: ``HF_TOKEN``,
    ``HUGGING_FACE_HUB_TOKEN``, ``HF_HUB_TOKEN``, ``HUGGINGFACEHUB_API_TOKEN``.
    The first non-empty value wins. When none is non-empty, the value looked
    up for the last variable is returned (``None`` if it is unset).
    """
    value = None
    for env_name in (
        "HF_TOKEN",
        "HUGGING_FACE_HUB_TOKEN",
        "HF_HUB_TOKEN",
        "HUGGINGFACEHUB_API_TOKEN",
    ):
        value = os.environ.get(env_name)
        if value:
            break
    return value


def _check_repo_access(repo_id: str, token: str) -> tuple[bool, str]:
    """Probe whether ``token`` can fetch model info for ``repo_id``.

    Returns:
        ``(True, "ok")`` when ``HfApi.model_info`` succeeds, otherwise
        ``(False, "<ExceptionType>: <message>")`` describing the failure.
    """
    client = HfApi()
    try:
        client.model_info(repo_id=repo_id, token=token)
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        # Any failure is reported as data so the caller can include it in its report.
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"


def _check_gated_base_access(repo_id: str, token: str) -> tuple[bool, str, str | None]:
    """Resolve the adapter's base model and verify its config can be loaded.

    Downloads ``adapter_config.json`` from ``repo_id``, reads its
    ``base_model_name_or_path`` entry, and loads that model's config with
    ``AutoConfig.from_pretrained`` using ``token``.

    Returns:
        ``(True, "ok", base_model)`` on success. On failure the first element
        is ``False``, the second is a description of the problem, and the
        third is ``None``.
    """
    try:
        config_file = hf_hub_download(repo_id, "adapter_config.json", token=token)
        with open(config_file, "r", encoding="utf-8") as handle:
            adapter_settings = json.load(handle)

        resolved_base = adapter_settings.get("base_model_name_or_path")
        has_base_name = isinstance(resolved_base, str) and bool(resolved_base)
        if not has_base_name:
            return False, "adapter_config missing base_model_name_or_path", None

        # Loading the config is what exercises access to the base model repo.
        AutoConfig.from_pretrained(resolved_base, token=token)
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        return False, f"{type(exc).__name__}: {exc}", None
    return True, "ok", resolved_base


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the health check.

    Options:
        ``--text-encoder``: one of the keys of ``TEXT_ENCODER_PRESETS``
        (default ``"llm2vec"``).
        ``--strict``: flag; documented as returning non-zero if any check fails.
    """
    preset_names = sorted(TEXT_ENCODER_PRESETS)

    cli = argparse.ArgumentParser(description="Kimodo text encoder health check")
    cli.add_argument(
        "--text-encoder",
        choices=preset_names,
        default="llm2vec",
        help="Text encoder preset to validate.",
    )
    cli.add_argument(
        "--strict",
        action="store_true",
        help="Return non-zero if any check fails.",
    )
    namespace = cli.parse_args()
    return namespace


def main() -> int:
    """Run the preflight checks, print a JSON report, and return an exit code.

    When ``TEXT_ENCODERS_DIR`` is set, only the existence of the preset's
    local paths under that directory is checked. Otherwise a Hugging Face
    token is required, and repo access plus base-config access are probed.

    Returns:
        ``2`` when ``--strict`` was given and at least one check failed,
        otherwise ``0``.
    """
    options = parse_args()
    repos = TEXT_ENCODER_PRESETS[options.text_encoder]
    base_repo = repos["base_model_name_or_path"]
    peft_repo = repos["peft_model_name_or_path"]

    hf_token = _get_hf_token()
    local_root = os.environ.get("TEXT_ENCODERS_DIR")

    checks: dict[str, dict] = {}
    report = {
        "text_encoder": options.text_encoder,
        "token_present": bool(hf_token),
        "token_length": len(hf_token) if hf_token else 0,
        "text_encoders_dir": local_root,
        "checks": checks,
    }

    if local_root:
        # Local mode: only verify that the expected paths exist.
        for check_name, repo in (
            ("base_local_path", base_repo),
            ("peft_local_path", peft_repo),
        ):
            candidate = os.path.join(local_root, repo)
            checks[check_name] = {"ok": os.path.exists(candidate), "path": candidate}
    elif not hf_token:
        checks["token"] = {
            "ok": False,
            "error": "No HF token found in HF_TOKEN/HUGGING_FACE_HUB_TOKEN/HF_HUB_TOKEN/HUGGINGFACEHUB_API_TOKEN",
        }
    else:
        # Remote mode: probe both repos, then the base model config.
        for check_name, repo in (
            ("base_repo_access", base_repo),
            ("peft_repo_access", peft_repo),
        ):
            reachable, detail = _check_repo_access(repo, hf_token)
            checks[check_name] = {"ok": reachable, "repo": repo, "detail": detail}

        gated_ok, gated_detail, gated_base = _check_gated_base_access(base_repo, hf_token)
        checks["gated_base_config_access"] = {
            "ok": gated_ok,
            "adapter_repo": base_repo,
            "base_model": gated_base,
            "detail": gated_detail,
        }

    # Every branch records an "ok" flag per check, so any false flag is a failure.
    any_failed = not all(entry["ok"] for entry in checks.values())

    print(json.dumps(report, indent=2, sort_keys=True))
    return 2 if (options.strict and any_failed) else 0


if __name__ == "__main__":
    raise SystemExit(main())