File size: 4,724 Bytes
08c5e28
 
fdc2b0b
08c5e28
 
 
 
 
 
 
 
 
 
 
 
fdc2b0b
08c5e28
7e0eb32
08c5e28
 
7e0eb32
08c5e28
fdc2b0b
08c5e28
 
 
 
 
 
 
 
 
 
 
 
fdc2b0b
08c5e28
 
 
 
 
 
 
 
 
 
fdc2b0b
08c5e28
 
fdc2b0b
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e0eb32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python3
"""
Download Dramabox models from HuggingFace.

Models are cached locally after first download.
Gemma text encoder is fetched separately from Google's repo.
"""
import logging
import os
from pathlib import Path

from huggingface_hub import hf_hub_download, snapshot_download

logger = logging.getLogger(__name__)

DRAMABOX_REPO = "ResembleAI/Dramabox"
GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"
REUSE_REPO = "nvidia/RE-USE"

# Default cache directory
DEFAULT_CACHE = os.path.join(os.environ.get("HF_HOME", os.path.expanduser("~")), ".cache", "dramabox")

# Model files in the HF repo (flat structure)
MODEL_FILES = {
    "transformer": "dramabox-dit-v1.safetensors",
    "audio_components": "dramabox-audio-components.safetensors",
    "silence_latent": "assets/silence_latent_frame.pt",
}


def get_model_path(name: str, cache_dir: str = None) -> str:
    """Download a model file from HF and return local path.

    Args:
        name: One of 'transformer', 'audio_components', 'silence_latent'
        cache_dir: Local cache directory (default: ~/.cache/dramabox)

    Returns:
        Local file path
    """
    cache_dir = cache_dir or DEFAULT_CACHE

    if name not in MODEL_FILES:
        raise ValueError(f"Unknown model: {name}. Choose from: {list(MODEL_FILES.keys())}")

    repo_path = MODEL_FILES[name]
    logger.info(f"Fetching {name} from {DRAMABOX_REPO}/{repo_path}...")

    local_path = hf_hub_download(
        repo_id=DRAMABOX_REPO,
        filename=repo_path,
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN"),
    )
    logger.info(f"  -> {local_path}")
    return local_path


def get_gemma_path(cache_dir: str = None) -> str:
    """Download Gemma 3 12B IT (pre-quantized bnb-4bit via unsloth) and return
    the snapshot directory. Using the pre-quantized variant skips runtime
    bitsandbytes quantization and ~halves the Gemma load time.
    """
    cache_dir = cache_dir or DEFAULT_CACHE
    logger.info(f"Fetching Gemma from {GEMMA_REPO}...")

    local_dir = snapshot_download(
        repo_id=GEMMA_REPO,
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN"),
    )
    logger.info(f"  -> {local_dir}")
    return local_dir


def get_reuse_code_path(cache_dir: str = None) -> str:
    """Fetch the nvidia/RE-USE code + configs needed by REUSEUpsampler.

    Only the .py / .yaml / .json files are pulled (~150 KB) — the 38 MB
    ``model.safetensors`` is intentionally skipped because
    ``SEMamba.from_pretrained("nvidia/RE-USE", ...)`` re-downloads weights
    through the standard HF cache on first instantiation, so vendoring them
    here would just duplicate ~38 MB on disk.

    Honors $REUSE_DIR for a pre-vendored copy (e.g. ``third_party/RE-USE/``):
    if set and exists, that path is returned without touching the network.
    Falls back to ``third_party/RE-USE/`` if it already contains the model
    file, otherwise snapshot-downloads into the dramabox cache.
    """
    env_dir = os.environ.get("REUSE_DIR")
    if env_dir and Path(env_dir).is_dir():
        return env_dir

    repo_root = Path(__file__).resolve().parent.parent
    local_vendor = repo_root / "third_party" / "RE-USE"
    if (local_vendor / "models" / "generator_SEMamba_time_d4.py").is_file():
        return str(local_vendor)

    cache_dir = cache_dir or DEFAULT_CACHE
    logger.info(f"Fetching RE-USE code/configs from {REUSE_REPO}...")
    local_dir = snapshot_download(
        repo_id=REUSE_REPO,
        cache_dir=cache_dir,
        token=os.environ.get("HF_TOKEN"),
        allow_patterns=["*.py", "*.yaml", "*.json",
                        "recipes/*", "models/*.py", "utils/*.py"],
    )
    logger.info(f"  -> {local_dir}")
    return local_dir


def get_all_paths(cache_dir: str = None) -> dict:
    """Download all required models and return paths dict.

    Returns:
        {
            'transformer': '/path/to/transformer.safetensors',
            'audio_components': '/path/to/audio-components.safetensors',
            'silence_latent': '/path/to/silence_latent_frame.pt',
            'gemma_root': '/path/to/unsloth/gemma-3-12b-it-bnb-4bit/',
        }
    """
    cache_dir = cache_dir or DEFAULT_CACHE
    paths = {}

    for name in MODEL_FILES:
        paths[name] = get_model_path(name, cache_dir)

    paths["gemma_root"] = get_gemma_path(cache_dir)
    return paths


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    paths = get_all_paths()
    print("\nAll models downloaded:")
    for k, v in paths.items():
        size = os.path.getsize(v) / 1e9 if os.path.isfile(v) else "dir"
        print(f"  {k}: {v} ({size:.2f}GB)" if isinstance(size, float) else f"  {k}: {v} (directory)")