"""Text encoder preflight health check for gated Hugging Face access and local cache paths.

The script prints a JSON report to stdout. Depending on the environment it
runs one of two kinds of checks:

* ``TEXT_ENCODERS_DIR`` set (non-empty): verify that the preset's base and
  PEFT repositories exist as paths below that directory.
* otherwise: verify that a Hugging Face token is available and that it can
  read the preset's repositories on the Hub, including the config of the
  underlying model named by the adapter config.

Exit status is ``2`` when ``--strict`` is given and any check failed, and
``0`` in every other case.
"""

from __future__ import annotations

import argparse
import json
import os

from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig

TEXT_ENCODER_PRESETS = {
    "llm2vec": {
        "base_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp",
        "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Meta-Llama-31-8B-Instruct-mntp-supervised",
    }
}

# Environment variables consulted for a Hugging Face token, in priority order.
_HF_TOKEN_ENV_VARS = (
    "HF_TOKEN",
    "HUGGING_FACE_HUB_TOKEN",
    "HF_HUB_TOKEN",
    "HUGGINGFACEHUB_API_TOKEN",
)


def _get_hf_token() -> str | None:
    """Return the first usable Hugging Face token found in the environment.

    The variables in ``_HF_TOKEN_ENV_VARS`` are tried in order. Surrounding
    whitespace (including a trailing newline) is stripped, and a variable that
    is unset, empty, or whitespace-only is skipped so that it cannot shadow a
    valid token stored in a later variable.

    Returns:
        The cleaned token, or ``None`` when no variable holds a usable value.
    """
    for name in _HF_TOKEN_ENV_VARS:
        candidate = os.environ.get(name, "").strip()
        if candidate:
            return candidate
    return None


def _check_repo_access(repo_id: str, token: str) -> tuple[bool, str]:
    """Check that ``token`` can read the metadata of model repo ``repo_id``.

    Args:
        repo_id: Hub repository id, e.g. ``"org/name"``.
        token: Hugging Face access token used for the request.

    Returns:
        ``(True, "ok")`` on success, otherwise ``(False, detail)`` where
        ``detail`` is ``"<ExceptionType>: <message>"``. This function never
        raises for request failures; they are reported through the tuple.
    """
    api = HfApi()
    try:
        api.model_info(repo_id=repo_id, token=token)
        return True, "ok"
    except Exception as error:  # pragma: no cover - depends on runtime/network/auth
        return False, f"{type(error).__name__}: {error}"


def _check_gated_base_access(repo_id: str, token: str) -> tuple[bool, str, str | None]:
    """Resolve adapter base model and verify config download entitlement.

    Downloads ``adapter_config.json`` from ``repo_id``, reads its
    ``base_model_name_or_path`` entry, and then loads that model's config with
    ``AutoConfig.from_pretrained`` using the same token.

    Args:
        repo_id: Hub repository id of the adapter repository.
        token: Hugging Face access token used for both downloads.

    Returns:
        ``(True, "ok", base_model)`` on success. On failure
        ``(False, detail, None)``, where ``detail`` is either a fixed message
        for a missing/invalid ``base_model_name_or_path`` or
        ``"<ExceptionType>: <message>"`` for any raised error.
    """
    try:
        adapter_cfg_path = hf_hub_download(repo_id, "adapter_config.json", token=token)
        with open(adapter_cfg_path, "r", encoding="utf-8") as f:
            adapter_cfg = json.load(f)
        base_model = adapter_cfg.get("base_model_name_or_path")
        if not isinstance(base_model, str) or not base_model:
            return False, "adapter_config missing base_model_name_or_path", None
        # Loading only the config is enough to exercise download permission
        # for the underlying model without fetching any weights.
        AutoConfig.from_pretrained(base_model, token=token)
        return True, "ok", base_model
    except Exception as error:  # pragma: no cover - depends on runtime/network/auth
        return False, f"{type(error).__name__}: {error}", None


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a normal CLI run.

    Returns:
        Namespace with ``text_encoder`` (preset name) and ``strict`` (bool).
    """
    parser = argparse.ArgumentParser(description="Kimodo text encoder health check")
    parser.add_argument(
        "--text-encoder",
        default="llm2vec",
        choices=sorted(TEXT_ENCODER_PRESETS.keys()),
        help="Text encoder preset to validate.",
    )
    parser.add_argument(
        "--strict",
        action="store_true",
        help="Return non-zero if any check fails.",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Run the health check, print the JSON report, and return the exit code.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`; ``None``
            means "use the process command line".

    Returns:
        ``2`` if ``--strict`` was given and at least one check failed,
        otherwise ``0`` (failures are still visible in the printed report).
    """
    args = parse_args(argv)
    preset = TEXT_ENCODER_PRESETS[args.text_encoder]
    base_repo = preset["base_model_name_or_path"]
    peft_repo = preset["peft_model_name_or_path"]

    token = _get_hf_token()
    text_encoders_dir = os.environ.get("TEXT_ENCODERS_DIR")

    # Only the presence and length of the token are reported, never its value.
    report = {
        "text_encoder": args.text_encoder,
        "token_present": bool(token),
        "token_length": len(token) if token else 0,
        "text_encoders_dir": text_encoders_dir,
        "checks": {},
    }
    checks = report["checks"]
    failed = False

    if text_encoders_dir:
        # Local mode: the repositories are expected below TEXT_ENCODERS_DIR
        # under their repo ids; no Hub access is attempted.
        base_path = os.path.join(text_encoders_dir, base_repo)
        peft_path = os.path.join(text_encoders_dir, peft_repo)
        base_ok = os.path.exists(base_path)
        peft_ok = os.path.exists(peft_path)
        checks["base_local_path"] = {"ok": base_ok, "path": base_path}
        checks["peft_local_path"] = {"ok": peft_ok, "path": peft_path}
        failed = not (base_ok and peft_ok)
    elif not token:
        # Remote mode without a token: nothing else can be verified.
        checks["token"] = {
            "ok": False,
            "error": "No HF token found in " + "/".join(_HF_TOKEN_ENV_VARS),
        }
        failed = True
    else:
        base_ok, base_error = _check_repo_access(base_repo, token)
        peft_ok, peft_error = _check_repo_access(peft_repo, token)
        checks["base_repo_access"] = {"ok": base_ok, "repo": base_repo, "detail": base_error}
        checks["peft_repo_access"] = {"ok": peft_ok, "repo": peft_repo, "detail": peft_error}

        # The preset's base repo is queried as an adapter repo: its
        # adapter_config.json names the underlying model whose config
        # download is then verified.
        gated_ok, gated_detail, gated_base = _check_gated_base_access(base_repo, token)
        checks["gated_base_config_access"] = {
            "ok": gated_ok,
            "adapter_repo": base_repo,
            "base_model": gated_base,
            "detail": gated_detail,
        }
        failed = not (base_ok and peft_ok and gated_ok)

    print(json.dumps(report, indent=2, sort_keys=True))

    if args.strict and failed:
        return 2
    return 0


if __name__ == "__main__":
    raise SystemExit(main())