# Spaces: Running on Zero
# File size: 4,621 Bytes
# 4be5ba2 560cef6 4be5ba2
"""Text encoder preflight health check for gated Hugging Face access and local cache paths."""
from __future__ import annotations
import argparse
import json
import os
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig
# Presets selectable through ``--text-encoder``. Each entry names the two
# Hugging Face repositories the health check validates; the same identifiers
# are also joined onto TEXT_ENCODERS_DIR when local paths are checked.
TEXT_ENCODER_PRESETS = {
    "llm2vec": dict(
        base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
    ),
}
def _get_hf_token() -> str | None:
    """Return the first non-empty Hugging Face token found in the environment.

    The variables are consulted in priority order: ``HF_TOKEN``,
    ``HUGGING_FACE_HUB_TOKEN``, ``HF_HUB_TOKEN``, ``HUGGINGFACEHUB_API_TOKEN``.

    Returns:
        The token string, or ``None`` when no variable holds a non-empty value.
        A variable that is set but empty is treated the same as an unset one,
        so "no usable token" is always reported as ``None`` (never ``""``).
    """
    candidate_vars = (
        "HF_TOKEN",
        "HUGGING_FACE_HUB_TOKEN",
        "HF_HUB_TOKEN",
        "HUGGINGFACEHUB_API_TOKEN",
    )
    for var_name in candidate_vars:
        value = os.environ.get(var_name)
        if value:
            return value
    # An ``or`` chain would leak the last operand here (possibly ""), making
    # the result depend on which variable happened to be empty.
    return None
def _check_repo_access(repo_id: str, token: str) -> tuple[bool, str]:
    """Report whether ``token`` can fetch the model metadata of ``repo_id``.

    Args:
        repo_id: Hub repository identifier passed to ``HfApi.model_info``.
        token: Access token forwarded with the request.

    Returns:
        ``(True, "ok")`` when the ``model_info`` call succeeds, otherwise
        ``(False, "<ExceptionType>: <message>")`` describing the failure.
    """
    client = HfApi()
    try:
        client.model_info(repo_id=repo_id, token=token)
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        # Any failure is turned into a report entry instead of aborting the run.
        return False, f"{type(exc).__name__}: {exc}"
    else:
        return True, "ok"
def _check_gated_base_access(repo_id: str, token: str) -> tuple[bool, str, str | None]:
    """Resolve an adapter's base model and verify its config can be loaded.

    Downloads ``adapter_config.json`` from ``repo_id``, reads its
    ``base_model_name_or_path`` entry and then loads that model's config
    with the same token.

    Args:
        repo_id: Repository expected to contain ``adapter_config.json``.
        token: Access token used for both the download and the config load.

    Returns:
        ``(True, "ok", base_model)`` on success; otherwise
        ``(False, detail, None)`` where ``detail`` is either the
        missing-field message or ``"<ExceptionType>: <message>"``.
    """
    try:
        config_file = hf_hub_download(repo_id, "adapter_config.json", token=token)
        with open(config_file, "r", encoding="utf-8") as handle:
            adapter_settings = json.load(handle)

        resolved_base = adapter_settings.get("base_model_name_or_path")
        if isinstance(resolved_base, str) and resolved_base:
            # Loading the config is the actual entitlement probe for the base model.
            AutoConfig.from_pretrained(resolved_base, token=token)
            return True, "ok", resolved_base
        return False, "adapter_config missing base_model_name_or_path", None
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        return False, f"{type(exc).__name__}: {exc}", None
def parse_args() -> argparse.Namespace:
    """Parse the health check's command-line options.

    Returns:
        A namespace with ``text_encoder`` (one of the ``TEXT_ENCODER_PRESETS``
        keys, default ``"llm2vec"``) and ``strict`` (``True`` when the
        ``--strict`` flag is given).
    """
    cli = argparse.ArgumentParser(description="Kimodo text encoder health check")
    cli.add_argument(
        "--text-encoder",
        choices=sorted(TEXT_ENCODER_PRESETS),
        default="llm2vec",
        help="Text encoder preset to validate.",
    )
    cli.add_argument(
        "--strict",
        action="store_true",
        help="Return non-zero if any check fails.",
    )
    return cli.parse_args()
def main() -> int:
    """Run the text encoder health check and print a JSON report.

    When ``TEXT_ENCODERS_DIR`` is set, only the existence of the preset's
    base and PEFT paths beneath it is checked. Otherwise a Hugging Face token
    is required, and repository access plus base-config access are probed.

    Returns:
        ``2`` when ``--strict`` was given and at least one check failed,
        otherwise ``0``.
    """
    options = parse_args()
    encoder_name = options.text_encoder
    repos = TEXT_ENCODER_PRESETS[encoder_name]
    base_repo = repos["base_model_name_or_path"]
    peft_repo = repos["peft_model_name_or_path"]

    hf_token = _get_hf_token()
    local_root = os.environ.get("TEXT_ENCODERS_DIR")

    checks: dict[str, dict] = {}
    report = {
        "text_encoder": encoder_name,
        "token_present": bool(hf_token),
        "token_length": len(hf_token) if hf_token else 0,
        "text_encoders_dir": local_root,
        "checks": checks,
    }

    if local_root:
        # Local mode: only verify that both repositories exist on disk.
        for label, repo in (("base_local_path", base_repo), ("peft_local_path", peft_repo)):
            candidate = os.path.join(local_root, repo)
            checks[label] = {"ok": os.path.exists(candidate), "path": candidate}
    elif not hf_token:
        checks["token"] = {
            "ok": False,
            "error": "No HF token found in HF_TOKEN/HUGGING_FACE_HUB_TOKEN/HF_HUB_TOKEN/HUGGINGFACEHUB_API_TOKEN",
        }
    else:
        for label, repo in (("base_repo_access", base_repo), ("peft_repo_access", peft_repo)):
            accessible, detail = _check_repo_access(repo, hf_token)
            checks[label] = {"ok": accessible, "repo": repo, "detail": detail}

        gated_ok, gated_detail, resolved_base = _check_gated_base_access(base_repo, hf_token)
        checks["gated_base_config_access"] = {
            "ok": gated_ok,
            "adapter_repo": base_repo,
            "base_model": resolved_base,
            "detail": gated_detail,
        }

    # Every branch above records at least one check, so the overall outcome
    # is simply whether all recorded checks passed.
    any_failed = not all(entry["ok"] for entry in checks.values())

    print(json.dumps(report, indent=2, sort_keys=True))
    return 2 if options.strict and any_failed else 0
if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|