ek100-retrieval-demo / model_loader.py
jsurrea's picture
Update demo: text-to-narration retrieval, add precomputed assets
3671814 verified
"""
Loads the AVION text encoder from HF Hub.
In demo mode we only need the text encoder — video embeddings are pre-computed.
Singleton — loaded once per pod lifetime.
"""
import os
import sys
import torch
from collections import OrderedDict
from huggingface_hub import hf_hub_download
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from avion.models.model_clip import CLIP_VITL14
# Process-wide cache for the loaded model: None until get_model() has built
# it once, after which the same instance is handed back on every call.
_model = None

# Suffix translation table for parameter names found in the full checkpoint.
# The checkpoint was saved with use_flash_attn=True, so its transformer blocks
# use the FlashMHA / fc1-fc2 naming; the model here is built with
# use_flash_attn=False and expects the nn.MultiheadAttention / c_fc-c_proj
# naming instead. Only the trailing part of a key is rewritten.
_KEY_MAP = {
    # Attention: fused QKV projection.
    ".attn.Wqkv.weight": ".attn.in_proj_weight",
    ".attn.Wqkv.bias": ".attn.in_proj_bias",
    # MLP: first (expansion) linear layer.
    ".mlp.fc1.weight": ".mlp.c_fc.weight",
    ".mlp.fc1.bias": ".mlp.c_fc.bias",
    # MLP: second (projection) linear layer.
    ".mlp.fc2.weight": ".mlp.c_proj.weight",
    ".mlp.fc2.bias": ".mlp.c_proj.bias",
}
def _remap_flash_attn_keys(state_dict: OrderedDict) -> OrderedDict:
    """Return a copy of *state_dict* with FlashMHA-style key suffixes renamed.

    Each key is checked against the suffixes in ``_KEY_MAP``; the first
    matching suffix is swapped for its replacement and all other keys pass
    through untouched. Values are not copied or modified, and the original
    key order is kept.
    """

    def _translate(name):
        # Only the tail of the name is rewritten; the prefix is preserved.
        for flash_suffix, standard_suffix in _KEY_MAP.items():
            if name.endswith(flash_suffix):
                return name[: -len(flash_suffix)] + standard_suffix
        return name

    return OrderedDict(
        (_translate(name), tensor) for name, tensor in state_dict.items()
    )
def get_model():
    """Return the shared AVION CLIP model, building it on the first call.

    The result is cached in the module-level ``_model``; later calls return
    the cached instance without touching the Hub again.

    Loading strategy:
      1. Download ``text_encoder.pt`` from the repo named by the
         ``MODEL_REPO_ID`` environment variable (default
         ``jsurrea/avion-vitl-ek100-sms``), authenticating with ``HF_TOKEN``
         if it is set.
      2. If anything in step 1 fails, download ``checkpoint_round_2.pt``
         instead, strip a leading ``module.`` from its keys, and translate
         FlashMHA-style key names via ``_remap_flash_attn_keys``.

    The weights are loaded non-strictly into a ``CLIP_VITL14`` built with
    ``use_flash_attn=False``, and the model is put in eval mode.

    Returns:
        The ``CLIP_VITL14`` instance, on CPU and in eval mode.

    Raises:
        Whatever the fallback download / ``torch.load`` / model construction
        raises; failures of the first (lightweight) attempt are swallowed and
        trigger the fallback instead.
    """
    global _model
    if _model is not None:
        return _model

    hf_token = os.environ.get("HF_TOKEN")
    repo_id = os.environ.get("MODEL_REPO_ID", "jsurrea/avion-vitl-ek100-sms")

    # Try lightweight text encoder first, fall back to full checkpoint.
    # The broad ``except`` is a deliberate best-effort: any failure here
    # (file absent, download error, unreadable file) routes to the fallback.
    try:
        print(f"Downloading text_encoder.pt from {repo_id}...")
        ckpt_path = hf_hub_download(
            repo_id=repo_id,
            filename="text_encoder.pt",
            token=hf_token,
        )
        # NOTE(security): weights_only=False allows arbitrary unpickling, so
        # this is only safe when the repo in MODEL_REPO_ID is trusted.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # The file may be either {"state_dict": ..., "embed_dim": ...} or a
        # bare state dict; both shapes are accepted.
        text_state = ckpt.get("state_dict", ckpt)
        embed_dim = ckpt.get("embed_dim", 256)
        text_only = True
        print(f"  Text encoder loaded ({len(text_state)} keys, embed_dim={embed_dim})")
    except Exception as e:
        print(f"text_encoder.pt not found ({e}), falling back to full checkpoint...")
        ckpt_path = hf_hub_download(
            repo_id=repo_id,
            filename="checkpoint_round_2.pt",
            token=hf_token,
        )
        # NOTE(security): same trust caveat as above for weights_only=False.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        raw_state = ckpt.get("state_dict", ckpt)
        # Strip only a *leading* "module." (presumably left by a
        # DataParallel/DDP wrapper at save time). A blanket str.replace would
        # also mangle keys that merely contain the substring, e.g.
        # "foo.submodule.weight" -> "foo.subweight".
        wrapper_prefix = "module."
        text_state = _remap_flash_attn_keys(
            OrderedDict(
                (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k, v)
                for k, v in raw_state.items()
            )
        )
        embed_dim = 256
        text_only = False
        print(f"  Full checkpoint loaded (epoch {ckpt.get('epoch', '?')})")

    model = CLIP_VITL14(
        freeze_temperature=True,
        use_grad_checkpointing=False,
        context_length=77,
        vocab_size=49408,
        patch_dropout=0.0,
        num_frames=16,
        drop_path_rate=0.0,
        use_fast_conv1=True,
        use_flash_attn=False,
        use_quick_gelu=False,
        project_embed_dim=embed_dim,
        pretrain_zoo=None,
    )

    # strict=False: the text-only file has no visual tower weights, so
    # missing "visual.*" keys are expected and must not raise.
    missing, unexpected = model.load_state_dict(text_state, strict=False)

    # Anything missing outside the visual tower (and logit_scale) means the
    # text encoder itself is incomplete, which is worth a warning.
    non_vision_missing = [
        k for k in missing
        if not k.startswith("visual.") and k != "logit_scale"
    ]
    if non_vision_missing:
        print(f"  WARNING — text encoder keys missing: {non_vision_missing[:5]}")
    # For the full checkpoint every key should line up, so report any drift.
    if not text_only and (missing or unexpected):
        print(f"  Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    model.eval()
    _model = model
    print("  Model ready.")
    return _model