""" Loads the AVION text encoder from HF Hub. In demo mode we only need the text encoder — video embeddings are pre-computed. Singleton — loaded once per pod lifetime. """ import os import sys import torch from collections import OrderedDict from huggingface_hub import hf_hub_download from pathlib import Path sys.path.insert(0, str(Path(__file__).parent)) from avion.models.model_clip import CLIP_VITL14 _model = None # Maps FlashMHA key names (checkpoint saved with use_flash_attn=True) # to standard nn.MultiheadAttention names. _KEY_MAP = { ".attn.Wqkv.weight": ".attn.in_proj_weight", ".attn.Wqkv.bias": ".attn.in_proj_bias", ".mlp.fc1.weight": ".mlp.c_fc.weight", ".mlp.fc1.bias": ".mlp.c_fc.bias", ".mlp.fc2.weight": ".mlp.c_proj.weight", ".mlp.fc2.bias": ".mlp.c_proj.bias", } def _remap_flash_attn_keys(state_dict: OrderedDict) -> OrderedDict: remapped = OrderedDict() for k, v in state_dict.items(): new_k = k for old_suffix, new_suffix in _KEY_MAP.items(): if k.endswith(old_suffix): new_k = k[: -len(old_suffix)] + new_suffix break remapped[new_k] = v return remapped def get_model(): global _model if _model is not None: return _model hf_token = os.environ.get("HF_TOKEN") repo_id = os.environ.get("MODEL_REPO_ID", "jsurrea/avion-vitl-ek100-sms") # Try lightweight text encoder first, fall back to full checkpoint try: print(f"Downloading text_encoder.pt from {repo_id}...") ckpt_path = hf_hub_download( repo_id=repo_id, filename="text_encoder.pt", token=hf_token, ) ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) text_state = ckpt.get("state_dict", ckpt) embed_dim = ckpt.get("embed_dim", 256) text_only = True print(f" Text encoder loaded ({len(text_state)} keys, embed_dim={embed_dim})") except Exception as e: print(f"text_encoder.pt not found ({e}), falling back to full checkpoint...") ckpt_path = hf_hub_download( repo_id=repo_id, filename="checkpoint_round_2.pt", token=hf_token, ) ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) raw_state = ckpt.get("state_dict", ckpt) text_state = _remap_flash_attn_keys( OrderedDict({k.replace("module.", ""): v for k, v in raw_state.items()}) ) embed_dim = 256 text_only = False print(f" Full checkpoint loaded (epoch {ckpt.get('epoch', '?')})") model = CLIP_VITL14( freeze_temperature=True, use_grad_checkpointing=False, context_length=77, vocab_size=49408, patch_dropout=0.0, num_frames=16, drop_path_rate=0.0, use_fast_conv1=True, use_flash_attn=False, use_quick_gelu=False, project_embed_dim=embed_dim, pretrain_zoo=None, ) missing, unexpected = model.load_state_dict(text_state, strict=False) non_vision_missing = [k for k in missing if not k.startswith("visual.") and k != "logit_scale"] if non_vision_missing: print(f" WARNING — text encoder keys missing: {non_vision_missing[:5]}") if not text_only and (missing or unexpected): print(f" Missing: {len(missing)}, Unexpected: {len(unexpected)}") model.eval() _model = model print(" Model ready.") return _model