Spaces:
Running on Zero
Running on Zero
Manmay Nakhashi committed on
Commit ·
8cd4942
1
Parent(s): 68ad4cc
Fetch unsloth bnb-4bit Gemma + Dramabox checkpoints from HF on startup
Browse files
Replaces hardcoded local paths (which only existed on the dev machine)
with model_downloader.get_all_paths(). Gemma now resolves to
unsloth/gemma-3-12b-it-bnb-4bit, matching the bnb_4bit=True path.
- app.py +6 -0
- src/inference_server.py +4 -2
app.py
CHANGED
|
@@ -16,11 +16,17 @@ import gradio as gr
|
|
| 16 |
# Local src import.
|
| 17 |
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
|
| 18 |
from inference_server import TTSServer # noqa: E402
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
|
|
|
|
|
|
| 22 |
logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
|
| 23 |
tts = TTSServer(
|
|
|
|
|
|
|
|
|
|
| 24 |
device="cuda",
|
| 25 |
dtype=os.environ.get("LTX_DTYPE", "bf16"),
|
| 26 |
compile_model=os.environ.get("LTX_COMPILE", "0") == "1",
|
|
|
|
| 16 |
# Local src import.
|
| 17 |
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
|
| 18 |
from inference_server import TTSServer # noqa: E402
|
| 19 |
+
from model_downloader import get_all_paths # noqa: E402
|
| 20 |
|
| 21 |
|
| 22 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 23 |
+
logging.info("Fetching DramaBox checkpoints from HuggingFace (cached after first run)...")
|
| 24 |
+
paths = get_all_paths()
|
| 25 |
logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
|
| 26 |
tts = TTSServer(
|
| 27 |
+
checkpoint=paths["transformer"],
|
| 28 |
+
full_checkpoint=paths["audio_components"],
|
| 29 |
+
gemma_root=paths["gemma_root"],
|
| 30 |
device="cuda",
|
| 31 |
dtype=os.environ.get("LTX_DTYPE", "bf16"),
|
| 32 |
compile_model=os.environ.get("LTX_COMPILE", "0") == "1",
|
src/inference_server.py
CHANGED
|
@@ -65,8 +65,10 @@ class TTSServer:
|
|
| 65 |
self.checkpoint = checkpoint or str(MODELS / "ltx-2.3-22b-dev-audio-only-v13-merged.safetensors")
|
| 66 |
self.full_checkpoint = full_checkpoint or os.environ.get(
|
| 67 |
"LTX_FULL_CHECKPOINT", "/mnt/persistent0/manmay/models/ltx23/ltx-2.3-22b-dev.safetensors")
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
self.device = torch.device(device)
|
| 71 |
self.dtype = torch.float16 if dtype == "fp16" else torch.bfloat16
|
| 72 |
self.compile_model = compile_model
|
|
|
|
| 65 |
self.checkpoint = checkpoint or str(MODELS / "ltx-2.3-22b-dev-audio-only-v13-merged.safetensors")
|
| 66 |
self.full_checkpoint = full_checkpoint or os.environ.get(
|
| 67 |
"LTX_FULL_CHECKPOINT", "/mnt/persistent0/manmay/models/ltx23/ltx-2.3-22b-dev.safetensors")
|
| 68 |
+
if gemma_root is None and not os.environ.get("GEMMA_DIR"):
|
| 69 |
+
from model_downloader import get_gemma_path
|
| 70 |
+
gemma_root = get_gemma_path()
|
| 71 |
+
self.gemma_root = gemma_root or os.environ["GEMMA_DIR"]
|
| 72 |
self.device = torch.device(device)
|
| 73 |
self.dtype = torch.float16 if dtype == "fp16" else torch.bfloat16
|
| 74 |
self.compile_model = compile_model
|