File size: 3,191 Bytes
08c5e28 fdc2b0b 08c5e28 fdc2b0b 08c5e28 fdc2b0b 08c5e28 fdc2b0b 08c5e28 fdc2b0b 08c5e28 fdc2b0b 08c5e28 fdc2b0b 08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | #!/usr/bin/env python3
"""
Download Dramabox models from HuggingFace.
Models are cached locally after first download.
Gemma text encoder is fetched separately from Google's repo.
"""
import logging
import os
from pathlib import Path
from huggingface_hub import hf_hub_download, snapshot_download
logger = logging.getLogger(__name__)
DRAMABOX_REPO = "ResembleAI/Dramabox"
GEMMA_REPO = "unsloth/gemma-3-12b-it-bnb-4bit"
# Default cache directory
DEFAULT_CACHE = os.environ.get(
"DRAMABOX_CACHE",
os.path.join(os.path.expanduser("~"), ".cache", "dramabox"),
)
# Model files in the HF repo (flat structure)
MODEL_FILES = {
"transformer": "dramabox-dit-v1.safetensors",
"audio_components": "dramabox-audio-components.safetensors",
"silence_latent": "assets/silence_latent_frame.pt",
}
def get_model_path(name: str, cache_dir: str = None) -> str:
"""Download a model file from HF and return local path.
Args:
name: One of 'transformer', 'audio_components', 'silence_latent'
cache_dir: Local cache directory (default: ~/.cache/dramabox)
Returns:
Local file path
"""
cache_dir = cache_dir or DEFAULT_CACHE
if name not in MODEL_FILES:
raise ValueError(f"Unknown model: {name}. Choose from: {list(MODEL_FILES.keys())}")
repo_path = MODEL_FILES[name]
logger.info(f"Fetching {name} from {DRAMABOX_REPO}/{repo_path}...")
local_path = hf_hub_download(
repo_id=DRAMABOX_REPO,
filename=repo_path,
cache_dir=cache_dir,
token=os.environ.get("HF_TOKEN"),
)
logger.info(f" -> {local_path}")
return local_path
def get_gemma_path(cache_dir: str = None) -> str:
"""Download Gemma 3 12B IT (pre-quantized bnb-4bit via unsloth) and return
the snapshot directory. Using the pre-quantized variant skips runtime
bitsandbytes quantization and ~halves the Gemma load time.
"""
cache_dir = cache_dir or DEFAULT_CACHE
logger.info(f"Fetching Gemma from {GEMMA_REPO}...")
local_dir = snapshot_download(
repo_id=GEMMA_REPO,
cache_dir=cache_dir,
token=os.environ.get("HF_TOKEN"),
)
logger.info(f" -> {local_dir}")
return local_dir
def get_all_paths(cache_dir: str = None) -> dict:
"""Download all required models and return paths dict.
Returns:
{
'transformer': '/path/to/transformer.safetensors',
'audio_components': '/path/to/audio-components.safetensors',
'silence_latent': '/path/to/silence_latent_frame.pt',
'gemma_root': '/path/to/unsloth/gemma-3-12b-it-bnb-4bit/',
}
"""
cache_dir = cache_dir or DEFAULT_CACHE
paths = {}
for name in MODEL_FILES:
paths[name] = get_model_path(name, cache_dir)
paths["gemma_root"] = get_gemma_path(cache_dir)
return paths
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
paths = get_all_paths()
print("\nAll models downloaded:")
for k, v in paths.items():
size = os.path.getsize(v) / 1e9 if os.path.isfile(v) else "dir"
print(f" {k}: {v} ({size:.2f}GB)" if isinstance(size, float) else f" {k}: {v} (directory)")
|