File size: 1,970 Bytes
6f0e045
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
extract_head.py
===============
Run this ONCE on your local machine (where torch is installed):

    cd D:\CoE\deploy
    python extract_head.py

Reads best_model_phase1.pt (1.1 GB) and saves ONLY the fine-tuned layers:
  - fusion.*          (attention + FFN + norms)     ~12 MB
  - classifier.*      (final classification head)
  - uncertainty_head.*
  - *_proj.*          (lightweight projection adapters)

These total ~25 MB — well within HF's 1 GB limit.
The four backbone encoders (CLIP, ViT, ResNet, EfficientNet) are NOT saved
because app.py downloads them from HF Hub at runtime for free.
"""

import os

import torch

# Paths resolved relative to this script so it works from any CWD.
CHECKPOINT = os.path.join(
    os.path.dirname(__file__),
    "..", "universal_vision_checkpoints", "best_model_phase1.pt"
)
OUTPUT = os.path.join(os.path.dirname(__file__), "head_weights.pt")

# BACKBONE prefixes — tensors under these keys are dropped because app.py
# re-downloads the encoders from HF Hub at runtime.
BACKBONE_PREFIXES = ("clip_model.", "vit.", "resnet.", "efficientnet.")


def filter_head_weights(state):
    """Return only the fine-tuned (non-backbone) tensors from *state*.

    A key belongs to the head iff it does not start with any of
    BACKBONE_PREFIXES (fusion.*, classifier.*, uncertainty_head.*, *_proj.*
    per the module docstring).
    """
    return {
        k: v for k, v in state.items()
        if not any(k.startswith(p) for p in BACKBONE_PREFIXES)
    }


def size_mb(state):
    """Total byte size of all tensors in *state*, in mebibytes."""
    return sum(v.numel() * v.element_size() for v in state.values()) / 1024**2


def main():
    """Load the full checkpoint, strip backbone tensors, save the head weights."""
    print(f"Loading: {os.path.abspath(CHECKPOINT)}")
    # NOTE(review): weights_only=False deserializes arbitrary pickled objects —
    # acceptable for a trusted local checkpoint, but prefer weights_only=True
    # if the file contains only tensors.
    ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
    # The checkpoint may wrap the weights under "model_state_dict" or be a
    # raw state dict itself.
    state = ckpt.get("model_state_dict", ckpt)

    head_state = filter_head_weights(state)

    print(f"\nFull checkpoint : {size_mb(state):.1f} MB  ({len(state)} tensors)")
    print(f"Head only       : {size_mb(head_state):.2f} MB  ({len(head_state)} tensors)")
    print("\nSaved keys:")
    for k, v in head_state.items():
        kb = v.numel() * v.element_size() / 1024
        print(f"  {k:55s}  {str(tuple(v.shape)):25s}  {kb:.1f} KB")

    torch.save({"model_state_dict": head_state}, OUTPUT)
    print(f"\n✅ Saved to: {os.path.abspath(OUTPUT)}")
    print(f"   Size: {os.path.getsize(OUTPUT)/1024**2:.2f} MB")
    print("\nNext step: push head_weights.pt to your HF Space repo (no LFS needed).")


# Guard so importing this module (e.g. for its helpers) does not trigger the
# 1.1 GB load / save side effects.
if __name__ == "__main__":
    main()