Spaces:
Running
Running
| """ | |
| Loads the AVION text encoder from HF Hub. | |
| In demo mode we only need the text encoder — video embeddings are pre-computed. | |
| Singleton — loaded once per pod lifetime. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| from collections import OrderedDict | |
| from huggingface_hub import hf_hub_download | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from avion.models.model_clip import CLIP_VITL14 | |
# Lazily-populated cache of the loaded model; get_model() fills it on first use.
_model = None

# Checkpoints trained with use_flash_attn=True name their attention / MLP
# parameters after the FlashMHA and fused-MLP modules, whereas the model built
# in this file uses use_flash_attn=False and expects nn.MultiheadAttention-style
# names. Each entry maps a checkpoint key *suffix* to its replacement suffix,
# for both the weight and the bias of every affected layer.
_KEY_MAP = {
    flash_module + param: torch_module + param
    for flash_module, torch_module in (
        (".attn.Wqkv.", ".attn.in_proj_"),
        (".mlp.fc1.", ".mlp.c_fc."),
        (".mlp.fc2.", ".mlp.c_proj."),
    )
    for param in ("weight", "bias")
}
def _remap_flash_attn_keys(state_dict: OrderedDict) -> OrderedDict:
    """Rename FlashMHA-style parameter keys to their standard equivalents.

    Every key is tested against the suffixes in ``_KEY_MAP``. The first suffix
    it ends with is replaced by the mapped suffix, and keys that match no
    suffix are kept as-is. Values are passed through untouched and the input
    key order is preserved.

    Args:
        state_dict: Mapping of parameter name to value.

    Returns:
        A new OrderedDict with renamed keys; ``state_dict`` itself is not
        modified.
    """

    def rename(key):
        for flash_suffix, torch_suffix in _KEY_MAP.items():
            if key.endswith(flash_suffix):
                return key[: -len(flash_suffix)] + torch_suffix
        return key

    return OrderedDict((rename(name), value) for name, value in state_dict.items())
# Projection dimension used when a checkpoint does not record one.
_DEFAULT_EMBED_DIM = 256
# Key prefix typically added when a model is saved while wrapped in
# torch's DataParallel / DistributedDataParallel.
_DDP_PREFIX = "module."


def _normalize_state_dict(state_dict):
    """Return a copy of *state_dict* with keys renamed for the non-flash model.

    A single *leading* ``module.`` prefix is removed, then FlashMHA /
    fused-MLP parameter names are mapped via ``_remap_flash_attn_keys``.
    Keys that already use standard names are left unchanged, so this is safe
    to apply to any checkpoint.
    """
    stripped = OrderedDict(
        (key[len(_DDP_PREFIX):] if key.startswith(_DDP_PREFIX) else key, value)
        for key, value in state_dict.items()
    )
    return _remap_flash_attn_keys(stripped)


def _load_text_encoder_checkpoint(repo_id, hf_token):
    """Download and load the lightweight ``text_encoder.pt`` from *repo_id*.

    Returns:
        ``(state_dict, embed_dim)``. ``embed_dim`` is read from the checkpoint
        and defaults to ``_DEFAULT_EMBED_DIM`` when absent.

    Raises:
        Any exception from downloading, unpickling or reading the checkpoint
        (e.g. when the file does not exist in the repo).
    """
    print(f"Downloading text_encoder.pt from {repo_id}...")
    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="text_encoder.pt",
        token=hf_token,
    )
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so MODEL_REPO_ID must point at a trusted repository.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = _normalize_state_dict(ckpt.get("state_dict", ckpt))
    embed_dim = ckpt.get("embed_dim", _DEFAULT_EMBED_DIM)
    print(f" Text encoder loaded ({len(state)} keys, embed_dim={embed_dim})")
    return state, embed_dim


def _load_full_checkpoint(repo_id, hf_token):
    """Download ``checkpoint_round_2.pt`` and return its normalized state dict.

    Only the state dict is returned, so any other entries in the loaded
    checkpoint (e.g. optimizer state, if the file contains it) become
    collectable before the model is built, which lowers peak memory.
    """
    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="checkpoint_round_2.pt",
        token=hf_token,
    )
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so MODEL_REPO_ID must point at a trusted repository.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = _normalize_state_dict(ckpt.get("state_dict", ckpt))
    print(f" Full checkpoint loaded (epoch {ckpt.get('epoch', '?')})")
    return state


def _build_clip_model(embed_dim):
    """Instantiate AVION's ViT-L/14 CLIP with ``project_embed_dim=embed_dim``.

    Flash attention is disabled, so the weights loaded into this model must use
    nn.MultiheadAttention parameter names.
    """
    return CLIP_VITL14(
        freeze_temperature=True,
        use_grad_checkpointing=False,
        context_length=77,
        vocab_size=49408,
        patch_dropout=0.0,
        num_frames=16,
        drop_path_rate=0.0,
        use_fast_conv1=True,
        use_flash_attn=False,
        use_quick_gelu=False,
        project_embed_dim=embed_dim,
        pretrain_zoo=None,
    )


def get_model():
    """Return the cached AVION CLIP model, loading it on the first call.

    Weights are looked up in the Hub repo named by ``MODEL_REPO_ID`` (default
    ``jsurrea/avion-vitl-ek100-sms``), with ``HF_TOKEN`` passed through as the
    download token:

    1. ``text_encoder.pt`` — presumably text-tower weights only. Missing
       ``visual.*`` keys are expected and are not reported.
    2. If step 1 fails for any reason, ``checkpoint_round_2.pt`` with an
       embed dim of ``_DEFAULT_EMBED_DIM``.

    Both checkpoints go through ``_normalize_state_dict`` before loading.
    The model is stored in the module-level ``_model``. No lock is taken, so
    concurrent first calls may each perform the load.

    Returns:
        The ``CLIP_VITL14`` model, switched to eval mode.
    """
    global _model
    if _model is not None:
        return _model

    hf_token = os.environ.get("HF_TOKEN")
    repo_id = os.environ.get("MODEL_REPO_ID", "jsurrea/avion-vitl-ek100-sms")

    # Try the lightweight text encoder first. The except is deliberately
    # broad: any failure (missing file, auth/network error, unexpected
    # checkpoint layout) falls back to the full checkpoint.
    try:
        text_state, embed_dim = _load_text_encoder_checkpoint(repo_id, hf_token)
    except Exception as e:
        print(f"text_encoder.pt not found ({e}), falling back to full checkpoint...")
        text_state = _load_full_checkpoint(repo_id, hf_token)
        embed_dim = _DEFAULT_EMBED_DIM
        text_only = False
    else:
        text_only = True

    model = _build_clip_model(embed_dim)
    # strict=False: the text-only checkpoint has no visual weights. Report
    # anything else that failed to load so a bad key mapping is visible.
    missing, unexpected = model.load_state_dict(text_state, strict=False)
    non_vision_missing = [
        k for k in missing if not k.startswith("visual.") and k != "logit_scale"
    ]
    if non_vision_missing:
        print(f" WARNING — text encoder keys missing: {non_vision_missing[:5]}")
    if not text_only and (missing or unexpected):
        print(f" Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    model.eval()
    _model = model
    print(" Model ready.")
    return _model