movimento / kimodo /scripts /text_encoder_health.py
rydlrKE's picture
Switch to LLM2Vec 3.1 pair to fix Space 401 on gated Llama 3.0
560cef6
raw
history blame
4.62 kB
"""Text encoder preflight health check for gated Hugging Face access and local cache paths."""
from __future__ import annotations
import argparse
import json
import os
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig
# Known text encoder presets, keyed by the value accepted for --text-encoder.
# Each preset names two Hugging Face repositories:
#   base_model_name_or_path - repo whose adapter_config.json is inspected to
#                             resolve the underlying base model
#   peft_model_name_or_path - repo holding the supervised PEFT weights
TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
        "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
    },
}
def _get_hf_token() -> str | None:
return (
os.environ.get("HF_TOKEN")
or os.environ.get("HUGGING_FACE_HUB_TOKEN")
or os.environ.get("HF_HUB_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
def _check_repo_access(repo_id: str, token: str) -> tuple[bool, str]:
    """Probe whether ``token`` can read the model metadata of ``repo_id``.

    Args:
        repo_id: Hugging Face model repository identifier.
        token: Access token forwarded to the ``model_info`` call.

    Returns:
        ``(True, "ok")`` when the metadata request succeeds, otherwise
        ``(False, "<ExceptionType>: <message>")`` describing the failure.
    """
    client = HfApi()
    try:
        client.model_info(repo_id=repo_id, token=token)
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        # Any failure is reported in the result rather than raised, so the
        # caller can still emit a complete report.
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"
def _check_gated_base_access(repo_id: str, token: str) -> tuple[bool, str, str | None]:
    """Resolve the base model named by an adapter repo and try loading its config.

    Downloads ``adapter_config.json`` from ``repo_id``, reads its
    ``base_model_name_or_path`` entry, and loads that model's configuration
    with ``AutoConfig.from_pretrained`` using the same token.

    Args:
        repo_id: Repository expected to contain ``adapter_config.json``.
        token: Access token used for both the download and the config load.

    Returns:
        ``(True, "ok", base_model)`` on success. On failure the tuple is
        ``(False, detail, None)``, where ``detail`` is either a fixed message
        for a missing/invalid base-model entry or ``"<ExceptionType>: <message>"``.
    """
    try:
        config_file = hf_hub_download(repo_id, "adapter_config.json", token=token)
        with open(config_file, "r", encoding="utf-8") as handle:
            adapter_settings = json.load(handle)

        resolved_base = adapter_settings.get("base_model_name_or_path")
        if isinstance(resolved_base, str) and resolved_base:
            # Loading the config is the actual entitlement probe; the loaded
            # object itself is not needed.
            AutoConfig.from_pretrained(resolved_base, token=token)
            return True, "ok", resolved_base
        return False, "adapter_config missing base_model_name_or_path", None
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        return False, f"{type(exc).__name__}: {exc}", None
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the health check.

    Options:
        --text-encoder: preset name, restricted to the keys of
            ``TEXT_ENCODER_PRESETS``; defaults to ``"llm2vec"``.
        --strict: flag; when given, a failed check yields a non-zero exit code.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="Kimodo text encoder health check")

    preset_names = sorted(TEXT_ENCODER_PRESETS)
    cli.add_argument(
        "--text-encoder",
        choices=preset_names,
        default="llm2vec",
        help="Text encoder preset to validate.",
    )
    cli.add_argument(
        "--strict",
        action="store_true",
        help="Return non-zero if any check fails.",
    )

    return cli.parse_args()
def main() -> int:
    """Run the preflight checks and print a JSON report to stdout.

    When ``TEXT_ENCODERS_DIR`` is set, only the existence of the preset's two
    repositories as paths below that directory is checked. Otherwise a token
    must be present in the environment, and it is used to probe both
    repositories and the base-model config resolved from the adapter repo.

    Returns:
        ``2`` when ``--strict`` was given and at least one check failed,
        ``0`` in every other case.
    """
    args = parse_args()
    selected = TEXT_ENCODER_PRESETS[args.text_encoder]
    base_repo = selected["base_model_name_or_path"]
    peft_repo = selected["peft_model_name_or_path"]

    token = _get_hf_token()
    local_root = os.environ.get("TEXT_ENCODERS_DIR")

    checks = {}
    if local_root:
        # Local mode: the repositories are expected as sub-paths of the directory.
        for label, repo in (("base_local_path", base_repo), ("peft_local_path", peft_repo)):
            candidate = os.path.join(local_root, repo)
            checks[label] = {"ok": os.path.exists(candidate), "path": candidate}
    elif not token:
        checks["token"] = {
            "ok": False,
            "error": "No HF token found in HF_TOKEN/HUGGING_FACE_HUB_TOKEN/HF_HUB_TOKEN/HUGGINGFACEHUB_API_TOKEN",
        }
    else:
        # Remote mode: probe both repositories first, then the resolved base config.
        for label, repo in (("base_repo_access", base_repo), ("peft_repo_access", peft_repo)):
            reachable, detail = _check_repo_access(repo, token)
            checks[label] = {"ok": reachable, "repo": repo, "detail": detail}

        gated_ok, gated_detail, gated_base = _check_gated_base_access(base_repo, token)
        checks["gated_base_config_access"] = {
            "ok": gated_ok,
            "adapter_repo": base_repo,
            "base_model": gated_base,
            "detail": gated_detail,
        }

    report = {
        "text_encoder": args.text_encoder,
        "token_present": bool(token),
        "token_length": len(token) if token else 0,
        "text_encoders_dir": local_root,
        "checks": checks,
    }
    print(json.dumps(report, indent=2, sort_keys=True))

    # Every recorded check carries an "ok" flag; the run failed if any is false.
    any_failed = not all(entry["ok"] for entry in checks.values())
    return 2 if args.strict and any_failed else 0
if __name__ == "__main__":
    # Script entry point: the value returned by main() becomes the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)