File size: 2,821 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Local model loader — loads encoder models from local cache, falls back to HF.

Model directories (saved via model.save_pretrained()):
  dinov2-small/    — facebook/dinov2-small (21M params, 384-dim) vision
  vit-base/        — google/vit-base-patch16-224 (86M, 768-dim) vision fallback
  moonshine-base/  — UsefulSensors/moonshine-base (62M, 416-dim) audio
  pig-vae/         — Wan2.1 VAE checkpoint (84M params) video latent codec
                     (no REGISTRY entry in this module, so it cannot be
                     loaded through load_encoder / load_processor)

Usage:
    from arbitor.encoders.models import load_encoder, load_processor
    
    model = load_encoder("dinov2-small")
    processor = load_processor("dinov2-small", "image")

Download models:
    python -m arbitor.encoders.models.download
"""
import os

# Directory holding this module; every local model cache is a subdirectory of it.
_MODELS_DIR = os.path.dirname(os.path.abspath(__file__))

# Short model name -> {"local": cache directory, "hf": HuggingFace repo name,
# "type": "auto"}.  The "type" field is "auto" for every entry and is not
# read by the loader functions in this file.
REGISTRY = {
    short_name: {
        "local": os.path.join(_MODELS_DIR, short_name),
        "hf": hub_repo,
        "type": "auto",
    }
    for short_name, hub_repo in (
        ("dinov2-small", "facebook/dinov2-small"),
        ("vit-base", "google/vit-base-patch16-224"),
        ("moonshine-base", "UsefulSensors/moonshine-base"),
    )
}


def resolve_path(name: str) -> tuple[str, dict]:
    """Pick where a registered model should be loaded from.

    Args:
        name: Short model name; must be a key of ``REGISTRY``.
    Returns:
        ``(location, entry)`` where ``location`` is the entry's local
        directory when that directory exists on disk, otherwise its
        HuggingFace repo name, and ``entry`` is the registry dict itself.
    Raises:
        ValueError: If ``name`` is not a key of ``REGISTRY``.
    """
    if name not in REGISTRY:
        raise ValueError(f"Unknown model: {name}. Options: {list(REGISTRY.keys())}")

    spec = REGISTRY[name]
    # Only the directory's existence is checked, not its contents.
    location = spec["local"] if os.path.isdir(spec["local"]) else spec["hf"]
    return location, spec


def load_encoder(name: str, device=None, **kwargs):
    """Load model from local cache, falling back to HuggingFace.

    Args:
        name: Short name ("dinov2-small", "vit-base", "moonshine-base")
        device: Optional device to move model to (e.g. "cuda", "cpu", or an
            integer device ordinal such as 0). ``None`` leaves the model
            wherever ``from_pretrained`` placed it.
        **kwargs: Forwarded to ``AutoModel.from_pretrained``.
            ``low_cpu_mem_usage`` defaults to True but may be overridden here.
    Returns:
        Loaded model in eval mode
    Raises:
        ValueError: If ``name`` is not a registered model (from ``resolve_path``).
    """
    from transformers import AutoModel

    path, _ = resolve_path(name)
    # Set as a default rather than a fixed keyword so that a caller passing
    # low_cpu_mem_usage explicitly does not hit a duplicate-keyword TypeError.
    kwargs.setdefault("low_cpu_mem_usage", True)
    model = AutoModel.from_pretrained(path, **kwargs)
    model.eval()
    # A plain truthiness test would skip the integer ordinal 0, which is a
    # valid ``.to()`` target; every other falsy value (None, "", False) is
    # still ignored exactly as before.
    is_ordinal_zero = (
        isinstance(device, int) and not isinstance(device, bool) and device == 0
    )
    if device or is_ordinal_zero:
        model = model.to(device)
    return model


def load_processor(name: str, modality: str = "image"):
    """Load the preprocessing object that accompanies a registered model.

    The source is resolved the same way as for the model itself: the local
    directory when it exists, otherwise the HuggingFace repo name.

    Args:
        name: Short model name
        modality: "audio" selects AutoFeatureExtractor; any other value
            (including the default "image") selects AutoImageProcessor
    Returns:
        Processor instance
    Raises:
        ValueError: If ``name`` is not a registered model (from ``resolve_path``).
    """
    source, _entry = resolve_path(name)

    # Everything that is not exactly "audio" takes the image branch.
    if modality != "audio":
        from transformers import AutoImageProcessor
        return AutoImageProcessor.from_pretrained(source)

    from transformers import AutoFeatureExtractor
    return AutoFeatureExtractor.from_pretrained(source)