Manmay Nakhashi commited on
Commit
8cd4942
·
1 Parent(s): 68ad4cc

Fetch unsloth bnb-4bit Gemma + Dramabox checkpoints from HF on startup

Browse files

Replaces hardcoded local paths (which only existed on the dev machine)
with model_downloader.get_all_paths(). Gemma now resolves to
unsloth/gemma-3-12b-it-bnb-4bit, matching the bnb_4bit=True path.

Files changed (2) hide show
  1. app.py +6 -0
  2. src/inference_server.py +4 -2
app.py CHANGED
@@ -16,11 +16,17 @@ import gradio as gr
16
  # Local src import.
17
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
18
  from inference_server import TTSServer # noqa: E402
 
19
 
20
 
21
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
 
 
22
  logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
23
  tts = TTSServer(
 
 
 
24
  device="cuda",
25
  dtype=os.environ.get("LTX_DTYPE", "bf16"),
26
  compile_model=os.environ.get("LTX_COMPILE", "0") == "1",
 
16
  # Local src import.
17
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
18
  from inference_server import TTSServer # noqa: E402
19
+ from model_downloader import get_all_paths # noqa: E402
20
 
21
 
22
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
23
+ logging.info("Fetching DramaBox checkpoints from HuggingFace (cached after first run)...")
24
+ paths = get_all_paths()
25
  logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
26
  tts = TTSServer(
27
+ checkpoint=paths["transformer"],
28
+ full_checkpoint=paths["audio_components"],
29
+ gemma_root=paths["gemma_root"],
30
  device="cuda",
31
  dtype=os.environ.get("LTX_DTYPE", "bf16"),
32
  compile_model=os.environ.get("LTX_COMPILE", "0") == "1",
src/inference_server.py CHANGED
@@ -65,8 +65,10 @@ class TTSServer:
65
  self.checkpoint = checkpoint or str(MODELS / "ltx-2.3-22b-dev-audio-only-v13-merged.safetensors")
66
  self.full_checkpoint = full_checkpoint or os.environ.get(
67
  "LTX_FULL_CHECKPOINT", "/mnt/persistent0/manmay/models/ltx23/ltx-2.3-22b-dev.safetensors")
68
- self.gemma_root = gemma_root or os.environ.get(
69
- "GEMMA_DIR", "/mnt/persistent0/manmay/models/gemma-3-12b-it-qat-q4_0-unquantized")
 
 
70
  self.device = torch.device(device)
71
  self.dtype = torch.float16 if dtype == "fp16" else torch.bfloat16
72
  self.compile_model = compile_model
 
65
  self.checkpoint = checkpoint or str(MODELS / "ltx-2.3-22b-dev-audio-only-v13-merged.safetensors")
66
  self.full_checkpoint = full_checkpoint or os.environ.get(
67
  "LTX_FULL_CHECKPOINT", "/mnt/persistent0/manmay/models/ltx23/ltx-2.3-22b-dev.safetensors")
68
+ if gemma_root is None and not os.environ.get("GEMMA_DIR"):
69
+ from model_downloader import get_gemma_path
70
+ gemma_root = get_gemma_path()
71
+ self.gemma_root = gemma_root or os.environ["GEMMA_DIR"]
72
  self.device = torch.device(device)
73
  self.dtype = torch.float16 if dtype == "fp16" else torch.bfloat16
74
  self.compile_model = compile_model