# dramabox-dit-int8 / inference_optimized.py
# Duplicate from moe2382/dramabox-dit-int8 (5e19135)
"""Optimized DramaBox inference: INT8 DiT + Gemma CPU offload.

Drops denoising-phase VRAM from 17.4 GB to 5.9 GB, making DramaBox usable
on 16 GB GPUs. Quality is preserved (MCD < 5.0 dB vs BF16 baseline).

Usage:
    python inference_optimized.py --text "Hello world!" --output output.wav

Requires the standard DramaBox installation plus `torchao`.
"""
import argparse
import logging
import re
import sys
import time
import torch
# Configure logging once at import time: INFO level, with a timestamp and the
# level name in front of every message emitted by this script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
def apply_int8_quantization(tts):
    """Quantize selected Linear layers of the DiT to INT8 (weight-only), in place.

    Two passes of ``torchao``'s ``quantize_`` are run over
    ``tts._velocity_model``:

    1. ``dit_filter`` selects ``torch.nn.Linear`` modules whose qualified name
       does not contain ``"norm"`` and which are one of:
       * a ``gate_logits`` layer,
       * an attention projection (``to_q`` / ``to_k`` / ``to_v`` / ``to_out``),
       * an ``audio_ff`` layer inside ``transformer_blocks.<idx>.`` that is
         either ``net.0.proj`` (any block) or ``net.2`` in a block with
         ``idx >= 15`` other than block 17.
    2. ``io_filter`` selects the Linear layers named exactly
       ``audio_patchify_proj`` and ``audio_proj_out``.

    Args:
        tts: Object exposing the model to quantize as ``_velocity_model``.

    Returns:
        None. The model is modified in place and the allocated VRAM of CUDA
        device 0 is logged afterwards.
    """
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    attention_markers = ("to_q", "to_k", "to_v", "to_out")
    io_layer_names = ("audio_patchify_proj", "audio_proj_out")
    first_ffn_block = 15
    excluded_ffn_blocks = {17}
    block_index_re = re.compile(r"transformer_blocks\.(\d+)\.")

    def dit_filter(mod, fqn):
        # Only Linear layers are candidates; anything under a "norm" name is skipped.
        if not isinstance(mod, torch.nn.Linear) or "norm" in fqn:
            return False
        if "gate_logits" in fqn:
            return True
        for marker in attention_markers:
            if marker in fqn:
                return True
        if "audio_ff" not in fqn:
            return False
        # NOTE(review): the source's indentation was lost; this assumes BOTH
        # feed-forward checks were nested under the block-index match. Confirm
        # that "net.0.proj" was not meant to apply without a block index.
        block_match = block_index_re.search(fqn)
        if block_match is None:
            return False
        block_idx = int(block_match.group(1))
        second_linear_selected = (
            "net.2" in fqn
            and block_idx >= first_ffn_block
            and block_idx not in excluded_ffn_blocks
        )
        return second_linear_selected or "net.0.proj" in fqn

    def io_filter(mod, fqn):
        # Exact-name match (not substring) for the two I/O projections.
        if fqn not in io_layer_names:
            return False
        return isinstance(mod, torch.nn.Linear)

    logging.info("Applying INT8 quantization to DiT...")
    velocity_model = tts._velocity_model
    for layer_filter in (dit_filter, io_filter):
        quantize_(velocity_model, Int8WeightOnlyConfig(), filter_fn=layer_filter)

    vram = torch.cuda.memory_allocated(0) / (1024**3)
    logging.info(f"INT8 applied. VRAM: {vram:.2f} GB")
def apply_gemma_offload(tts):
    """Patch the prompt encoder so its text encoder lives on the CPU between uses.

    The ``__call__`` of ``type(tts._prompt_encoder)`` is replaced, so the patch
    applies to every instance of that class, not only to ``tts``'s encoder.

    Behaviour of the patched call:

    * If the encoder is not "warm" (``self._warm`` falsy or
      ``self._warm_text_encoder`` is ``None``), the original ``__call__`` is
      invoked unchanged, including ``**kwargs``.
    * Otherwise the text encoder is moved to ``tts.device`` if it is currently
      on the CPU, every prompt is encoded, and the text encoder is moved back
      to the CPU and the CUDA cache is released. The move back happens in a
      ``finally`` block, so a failure while moving or encoding (for example an
      out-of-memory error) no longer leaves the encoder on the GPU.
      ``**kwargs`` are not forwarded on this path.

    Args:
        tts: Object exposing ``_prompt_encoder`` and ``device``.

    Returns:
        None. The prompt-encoder class is modified in place.
    """
    pe_cls = type(tts._prompt_encoder)
    orig_call = pe_cls.__call__

    def _offload_call(self, prompts, **kwargs):
        if not (self._warm and self._warm_text_encoder is not None):
            return orig_call(self, prompts, **kwargs)

        te = self._warm_text_encoder
        # The device of the first parameter is taken as the device of the model.
        is_on_cpu = next(te.parameters()).device.type == "cpu"
        try:
            if is_on_cpu:
                logging.info("Moving Gemma to GPU...")
                te.to(tts.device)
            raw_outputs = [te.encode(p) for p in prompts]
        finally:
            # Always offload, even if the transfer or encoding raised, so the
            # memory is available again for whatever runs next.
            logging.info("Offloading Gemma to CPU...")
            te.to("cpu")
            torch.cuda.empty_cache()

        ep = self._warm_embeddings_processor
        return [ep.process_hidden_states(hs, mask) for hs, mask in raw_outputs]

    pe_cls.__call__ = _offload_call
    logging.info("Gemma CPU offload enabled")
def main():
    """Command-line entry point: load DramaBox, optimize it, synthesize one text.

    Steps, in order:

    1. Parse ``--text`` (required), ``--output`` and the two opt-out flags
       ``--no-offload`` / ``--no-quantize``.
    2. Put ``/app/dramabox/src`` at the front of ``sys.path`` and import the
       DramaBox helpers from there.
    3. Build a ``TTSServer`` on CUDA in bf16 with ``bnb_4bit=True``.
    4. Apply INT8 quantization and the Gemma CPU offload unless disabled.
    5. Run ``tts.generate`` for the given text and log the elapsed time and
       the peak allocated VRAM of CUDA device 0.
    """
    cli = argparse.ArgumentParser(description="DramaBox INT8 optimized inference")
    cli.add_argument("--text", required=True, help="Text to synthesize")
    cli.add_argument("--output", default="output.wav", help="Output WAV path")
    opt_out_flags = (
        ("--no-offload", "Disable Gemma CPU offload"),
        ("--no-quantize", "Disable INT8 quantization"),
    )
    for flag, description in opt_out_flags:
        cli.add_argument(flag, action="store_true", help=description)
    args = cli.parse_args()

    # The project modules are imported only after argument parsing, from a
    # fixed install location.
    sys.path.insert(0, "/app/dramabox/src")
    from model_downloader import get_all_paths
    from inference_server import TTSServer

    model_paths = get_all_paths()
    logging.info("Loading DramaBox...")
    server_options = {
        "checkpoint": model_paths["transformer"],
        "full_checkpoint": model_paths["audio_components"],
        "gemma_root": model_paths["gemma_root"],
        "device": "cuda",
        "dtype": "bf16",
        "compile_model": False,
        "bnb_4bit": True,
    }
    tts = TTSServer(**server_options)

    quantize_enabled = not args.no_quantize
    offload_enabled = not args.no_offload
    if quantize_enabled:
        apply_int8_quantization(tts)
    if offload_enabled:
        apply_gemma_offload(tts)

    logging.info(f"Generating: {args.text[:80]}...")
    started_at = time.time()
    # The return value is not used here; presumably the audio is written to
    # ``args.output`` by the call itself (return_type="file") - TODO confirm.
    _result = tts.generate(
        prompts=[{"text": args.text}],
        return_type="file",
        output_path=args.output,
    )
    finished_at = time.time()
    elapsed = finished_at - started_at

    peak_vram = torch.cuda.max_memory_allocated(0) / (1024**3)
    logging.info(
        f"Done in {elapsed:.2f}s | Peak VRAM: {peak_vram:.2f} GB | Saved: {args.output}"
    )
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()