File size: 4,402 Bytes
"""Optimized DramaBox inference: INT8 DiT + Gemma CPU offload.

Drops denoising-phase VRAM from 17.4 GB to 5.9 GB, making DramaBox usable
on 16 GB GPUs. Quality is preserved (MCD < 5.0 dB vs. the BF16 baseline).

Usage:
    python inference_optimized.py --text "Hello world!" --output output.wav

Requires the standard DramaBox installation plus `torchao`.
"""
import argparse
import logging
import re
import sys
import time
import torch
# Configure the root logger at import time so every message emitted by the
# helpers below carries a timestamp and severity level.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
def apply_int8_quantization(tts):
    """Quantize selected Linear layers of the DiT to INT8 (weight-only).

    Two ``quantize_`` passes are run in place over ``tts._velocity_model``:
    the first selects layers inside the transformer (gates, attention
    projections and part of the audio feed-forward stack), the second
    selects the two top-level audio input/output projections.

    Args:
        tts: Object exposing the DiT as ``_velocity_model``.
    """
    # Imported lazily so the module can be loaded without torchao installed.
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    attention_names = ("to_q", "to_k", "to_v", "to_out")
    first_ffn_block = 15  # "net.2" is only quantized from this block index on
    excluded_ffn_blocks = {17}  # ...except for these block indices
    io_layer_names = ("audio_patchify_proj", "audio_proj_out")

    def _select_dit_layer(mod, fqn):
        """Return True when ``mod`` (named ``fqn``) should be quantized."""
        # Only Linear layers are candidates; anything under a norm is skipped.
        if not isinstance(mod, torch.nn.Linear) or "norm" in fqn:
            return False
        if "gate_logits" in fqn:
            return True
        for name in attention_names:
            if name in fqn:
                return True
        if "audio_ff" not in fqn:
            return False
        # Feed-forward layers are selected per transformer block index.
        # NOTE(review): assumes both FFN checks apply only when a block
        # index is present in the name -- confirm against the upstream file.
        block = re.search(r"transformer_blocks\.(\d+)\.", fqn)
        if block is None:
            return False
        if "net.0.proj" in fqn:
            return True
        block_idx = int(block.group(1))
        return (
            "net.2" in fqn
            and block_idx >= first_ffn_block
            and block_idx not in excluded_ffn_blocks
        )

    def _select_io_layer(mod, fqn):
        """Return True for the top-level audio in/out Linear projections."""
        return fqn in io_layer_names and isinstance(mod, torch.nn.Linear)

    logging.info("Applying INT8 quantization to DiT...")
    for layer_filter in (_select_dit_layer, _select_io_layer):
        quantize_(
            tts._velocity_model, Int8WeightOnlyConfig(), filter_fn=layer_filter
        )

    # Report memory currently allocated on CUDA device 0, in GiB.
    vram = torch.cuda.memory_allocated(0) / (1024**3)
    logging.info(f"INT8 applied. VRAM: {vram:.2f} GB")
def apply_gemma_offload(tts):
    """Patch the prompt encoder to offload Gemma 12B to CPU between uses.

    ``__call__`` is replaced on the *class* of ``tts._prompt_encoder``, so
    every instance of that class is affected. When the encoder is "warm"
    (``_warm`` is truthy and ``_warm_text_encoder`` is set), the wrapper
    moves the text encoder to ``tts.device`` if it currently sits on the
    CPU, encodes every prompt, moves the encoder back to the CPU, frees the
    CUDA cache and finally runs the embeddings processor on the results.
    Otherwise the original ``__call__`` is used unchanged.

    The move back to the CPU happens in a ``finally`` clause so that a
    failure while loading or encoding does not leave the text encoder
    resident on the GPU.

    Args:
        tts: Object exposing ``_prompt_encoder`` and ``device``.
    """
    encoder_cls = type(tts._prompt_encoder)
    original_call = encoder_cls.__call__

    def _offload_call(self, prompts, **kwargs):
        # Cold encoder: nothing to offload, defer to the original behavior.
        if not self._warm or self._warm_text_encoder is None:
            return original_call(self, prompts, **kwargs)

        # NOTE(review): ``kwargs`` are not forwarded on this warm path (same
        # as before this change) -- confirm no caller relies on them.
        text_encoder = self._warm_text_encoder
        try:
            if next(text_encoder.parameters()).device.type == "cpu":
                logging.info("Moving Gemma to GPU...")
                text_encoder.to(tts.device)
            raw_outputs = [text_encoder.encode(prompt) for prompt in prompts]
        finally:
            # Always release GPU memory, even if loading or encoding raised.
            logging.info("Offloading Gemma to CPU...")
            text_encoder.to("cpu")
            torch.cuda.empty_cache()

        processor = self._warm_embeddings_processor
        return [
            processor.process_hidden_states(hidden, mask)
            for hidden, mask in raw_outputs
        ]

    encoder_cls.__call__ = _offload_call
    logging.info("Gemma CPU offload enabled")
def main():
    """Command-line entry point: load DramaBox, optimize it, synthesize once.

    Parses ``--text`` (required), ``--output`` and the two opt-out flags
    ``--no-offload`` / ``--no-quantize``, builds a ``TTSServer``, applies
    INT8 quantization and Gemma CPU offload unless disabled, generates one
    audio file and logs the elapsed time and peak CUDA memory on device 0.
    """
    cli = argparse.ArgumentParser(description="DramaBox INT8 optimized inference")
    cli.add_argument("--text", required=True, help="Text to synthesize")
    cli.add_argument("--output", default="output.wav", help="Output WAV path")
    boolean_flags = (
        ("--no-offload", "Disable Gemma CPU offload"),
        ("--no-quantize", "Disable INT8 quantization"),
    )
    for flag, description in boolean_flags:
        cli.add_argument(flag, action="store_true", help=description)
    opts = cli.parse_args()

    # The DramaBox sources are not installed as a package; make them
    # importable before pulling in the project modules.
    sys.path.insert(0, "/app/dramabox/src")
    from model_downloader import get_all_paths
    from inference_server import TTSServer

    model_paths = get_all_paths()
    logging.info("Loading DramaBox...")
    server = TTSServer(
        checkpoint=model_paths["transformer"],
        full_checkpoint=model_paths["audio_components"],
        gemma_root=model_paths["gemma_root"],
        device="cuda",
        dtype="bf16",
        compile_model=False,
        bnb_4bit=True,
    )

    # Quantization runs before the offload patch, as in the original flow.
    quantize_enabled = not opts.no_quantize
    offload_enabled = not opts.no_offload
    if quantize_enabled:
        apply_int8_quantization(server)
    if offload_enabled:
        apply_gemma_offload(server)

    logging.info(f"Generating: {opts.text[:80]}...")
    started = time.time()
    # The return value is not used; the audio is written to ``opts.output``.
    _generated = server.generate(
        prompts=[{"text": opts.text}],
        return_type="file",
        output_path=opts.output,
    )
    elapsed = time.time() - started

    peak_vram = torch.cuda.max_memory_allocated(0) / (1024**3)
    logging.info(
        f"Done in {elapsed:.2f}s | Peak VRAM: {peak_vram:.2f} GB | Saved: {opts.output}"
    )
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|