# movimento/kimodo/scripts/text_encoder_health.py
# Commit 560cef6 by rydlrKE: "Switch to LLM2Vec 3.1 pair to fix Space 401 on gated Llama 3.0"
"""Text encoder preflight health check for gated Hugging Face access and local cache paths."""
from __future__ import annotations
import argparse
import json
import os
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig
# Text-encoder presets selectable via --text-encoder.
# NOTE(review): the "base" entry is itself an LLM2Vec MNTP adapter repo; its own
# adapter_config.json is presumed to name the gated meta-llama base model — this
# is what _check_gated_base_access downloads and resolves. Confirm against the
# repo contents if the preset changes.
TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
        "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
    }
}
def _get_hf_token() -> str | None:
return (
os.environ.get("HF_TOKEN")
or os.environ.get("HUGGING_FACE_HUB_TOKEN")
or os.environ.get("HF_HUB_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
def _check_repo_access(repo_id: str, token: str) -> tuple[bool, str]:
    """Probe the Hub for *repo_id* metadata with *token*.

    Returns ``(True, "ok")`` when ``model_info`` succeeds, otherwise
    ``(False, "<ExceptionName>: <message>")``.
    """
    try:
        HfApi().model_info(repo_id=repo_id, token=token)
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"
def _check_gated_base_access(repo_id: str, token: str) -> tuple[bool, str, str | None]:
    """Resolve adapter base model and verify config download entitlement.

    Downloads ``adapter_config.json`` from *repo_id*, reads its
    ``base_model_name_or_path``, then fetches that base model's config —
    the fetch is the actual probe for gated-repo access.

    Returns ``(ok, detail, base_model_or_None)``.
    """
    try:
        cfg_file = hf_hub_download(repo_id, "adapter_config.json", token=token)
        with open(cfg_file, "r", encoding="utf-8") as handle:
            base_model = json.load(handle).get("base_model_name_or_path")
        if not (isinstance(base_model, str) and base_model):
            return False, "adapter_config missing base_model_name_or_path", None
        # Will raise (e.g. 401/403) if the token lacks access to the gated base.
        AutoConfig.from_pretrained(base_model, token=token)
    except Exception as exc:  # pragma: no cover - depends on runtime/network/auth
        return False, f"{type(exc).__name__}: {exc}", None
    return True, "ok", base_model
def parse_args() -> argparse.Namespace:
    """Parse CLI flags for the health check.

    ``--text-encoder`` selects a preset key; ``--strict`` makes the process
    exit non-zero when any check fails.
    """
    cli = argparse.ArgumentParser(description="Kimodo text encoder health check")
    cli.add_argument(
        "--text-encoder",
        default="llm2vec",
        choices=sorted(TEXT_ENCODER_PRESETS.keys()),
        help="Text encoder preset to validate.",
    )
    cli.add_argument(
        "--strict",
        action="store_true",
        help="Return non-zero if any check fails.",
    )
    return cli.parse_args()
def main() -> int:
    """Run the preflight checks and print a JSON report to stdout.

    Two modes:
      * TEXT_ENCODERS_DIR set  -> only verify both snapshot paths exist locally.
      * otherwise              -> require a token, then probe Hub access for the
                                  base/PEFT repos and the gated base config.

    Returns 2 when ``--strict`` is set and any check failed, else 0.
    """
    args = parse_args()
    preset = TEXT_ENCODER_PRESETS[args.text_encoder]
    base_repo = preset["base_model_name_or_path"]
    peft_repo = preset["peft_model_name_or_path"]
    token = _get_hf_token()
    local_dir = os.environ.get("TEXT_ENCODERS_DIR")

    checks: dict = {}
    report = {
        "text_encoder": args.text_encoder,
        "token_present": bool(token),
        "token_length": len(token) if token else 0,
        "text_encoders_dir": local_dir,
        "checks": checks,
    }

    failures = []
    if local_dir:
        # Local mode: the repo id (with its "/") is joined as a nested path.
        for key, repo in (("base_local_path", base_repo), ("peft_local_path", peft_repo)):
            path = os.path.join(local_dir, repo)
            exists = os.path.exists(path)
            checks[key] = {"ok": exists, "path": path}
            if not exists:
                failures.append(key)
    elif not token:
        checks["token"] = {
            "ok": False,
            "error": "No HF token found in HF_TOKEN/HUGGING_FACE_HUB_TOKEN/HF_HUB_TOKEN/HUGGINGFACEHUB_API_TOKEN",
        }
        failures.append("token")
    else:
        for key, repo in (("base_repo_access", base_repo), ("peft_repo_access", peft_repo)):
            ok, detail = _check_repo_access(repo, token)
            checks[key] = {"ok": ok, "repo": repo, "detail": detail}
            if not ok:
                failures.append(key)
        # The base repo is itself a PEFT adapter; this resolves and probes its
        # (gated) underlying base model's config.
        gated_ok, gated_detail, gated_base = _check_gated_base_access(base_repo, token)
        checks["gated_base_config_access"] = {
            "ok": gated_ok,
            "adapter_repo": base_repo,
            "base_model": gated_base,
            "detail": gated_detail,
        }
        if not gated_ok:
            failures.append("gated_base_config_access")

    print(json.dumps(report, indent=2, sort_keys=True))
    return 2 if args.strict and failures else 0
if __name__ == "__main__":
    # Propagate main()'s int return code as the process exit status.
    raise SystemExit(main())