| """Optimized DramaBox inference: INT8 DiT + Gemma CPU offload. |
| |
| Drops denoising-phase VRAM from 17.4 GB to 5.9 GB, making DramaBox usable |
| on 16 GB GPUs. Quality is preserved (MCD < 5.0 dB vs BF16 baseline). |
| |
| Usage: |
| python inference_optimized.py --text "Hello world!" --output output.wav |
| |
| Requires the standard DramaBox installation plus `torchao`. |
| """ |
| import argparse |
| import logging |
| import re |
| import sys |
| import time |
|
|
| import torch |
|
|
# Root logger setup for this script: INFO level, timestamped one-line records.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
|
|
|
|
def apply_int8_quantization(tts, *, ffn_start=15, ffn_skip=frozenset({17})):
    """Apply selective INT8 weight-only quantization to the DiT.

    Runs a single torchao ``quantize_`` pass with ``Int8WeightOnlyConfig`` over
    ``tts._velocity_model``. Only ``torch.nn.Linear`` modules are eligible,
    and they are selected by fully-qualified name:

    * ``audio_patchify_proj`` and ``audio_proj_out`` (the audio I/O projections);
    * any name containing ``"norm"`` is skipped;
    * names containing ``gate_logits`` or an attention projection key
      (``to_q`` / ``to_k`` / ``to_v`` / ``to_out``);
    * inside ``audio_ff``: every ``net.0.proj`` layer, and ``net.2`` layers of
      ``transformer_blocks.<idx>`` only when ``idx >= ffn_start`` and
      ``idx not in ffn_skip``.

    Args:
        tts: Loaded server object; its ``_velocity_model`` is quantized in place.
        ffn_start: First transformer block index whose ``audio_ff`` ``net.2``
            layer is quantized (default 15, the original hard-coded value).
        ffn_skip: Block indices whose ``audio_ff`` ``net.2`` layer is left in
            full precision even at or above ``ffn_start`` (default ``{17}``,
            the original hard-coded value).

    Raises:
        ImportError: if ``torchao`` is not installed.
    """
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    attn_proj_keys = ("to_q", "to_k", "to_v", "to_out")
    io_proj_names = ("audio_patchify_proj", "audio_proj_out")
    # Compiled once per call rather than re-resolved for every module visited.
    block_idx_re = re.compile(r"transformer_blocks\.(\d+)\.")

    def _is_int8_target(mod, fqn):
        """Return True if the module at ``fqn`` should receive INT8 weights."""
        if not isinstance(mod, torch.nn.Linear):
            return False
        # Audio I/O projections (previously selected by a second quantize_ pass).
        if fqn in io_proj_names:
            return True
        if "norm" in fqn:
            return False
        if "gate_logits" in fqn:
            return True
        if any(key in fqn for key in attn_proj_keys):
            return True
        if "audio_ff" in fqn:
            m = block_idx_re.search(fqn)
            if m:
                idx = int(m.group(1))
                # net.2 is only quantized from ffn_start on, minus explicit skips.
                if "net.2" in fqn and idx >= ffn_start and idx not in ffn_skip:
                    return True
            # net.0.proj is quantized regardless of block index.
            # NOTE(review): the source's indentation was lost; for names of the
            # form transformer_blocks.N.audio_ff... both readings agree.
            if "net.0.proj" in fqn:
                return True
        return False

    logging.info("Applying INT8 quantization to DiT...")
    # The I/O names never match the DiT rules (and vice versa), so one pass
    # over the module tree is equivalent to the former two passes.
    quantize_(tts._velocity_model, Int8WeightOnlyConfig(), filter_fn=_is_int8_target)

    vram = torch.cuda.memory_allocated(0) / (1024**3)
    logging.info(f"INT8 applied. VRAM: {vram:.2f} GB")
|
|
|
|
def apply_gemma_offload(tts):
    """Patch PromptEncoder to offload Gemma 12B to CPU between uses.

    Replaces ``__call__`` on the *class* of ``tts._prompt_encoder``. Python
    looks up dunder methods on the type, so an instance attribute would be
    ignored. As a result, every instance of that class is affected.

    When the encoder is warm (``self._warm``) and has a
    ``_warm_text_encoder``, the patched call does the following:

    1. Moves the text encoder to ``tts.device`` if its parameters are on CPU.
    2. Encodes every prompt with gradient tracking disabled.
    3. Moves the text encoder back to CPU and calls
       ``torch.cuda.empty_cache()``. This happens even when encoding raises,
       so a failed request does not leave Gemma resident on the GPU.
    4. Post-processes each ``(hidden_states, mask)`` pair with
       ``_warm_embeddings_processor.process_hidden_states``.

    In every other case it defers to the original ``__call__``.
    """
    import functools

    pe_cls = type(tts._prompt_encoder)
    orig_call = pe_cls.__call__

    @functools.wraps(orig_call)
    def _offload_call(self, prompts, **kwargs):
        # Short-circuit keeps _warm_text_encoder unread when not warm.
        if not (self._warm and self._warm_text_encoder is not None):
            return orig_call(self, prompts, **kwargs)

        # NOTE(review): kwargs are not forwarded on this warm path, matching
        # the original patch; confirm the original __call__ does not need them.
        te = self._warm_text_encoder
        if next(te.parameters()).device.type == "cpu":
            logging.info("Moving Gemma to GPU...")
            te.to(tts.device)

        try:
            # Inference only: do not retain autograd state for the encoder pass.
            with torch.no_grad():
                raw_outputs = [te.encode(p) for p in prompts]
        finally:
            logging.info("Offloading Gemma to CPU...")
            te.to("cpu")
            torch.cuda.empty_cache()

        ep = self._warm_embeddings_processor
        return [ep.process_hidden_states(hs, mask) for hs, mask in raw_outputs]

    pe_cls.__call__ = _offload_call
    logging.info("Gemma CPU offload enabled")
|
|
|
|
def main():
    """Parse CLI arguments, load DramaBox, apply the optimizations, synthesize.

    The model is loaded with ``device="cuda"``, so the process exits with an
    error message when no CUDA device is available.

    Two VRAM figures are logged. The first is the setup peak, which covers
    model loading before INT8 quantization / offload are applied. The CUDA
    peak counter is then reset, so the final "Peak VRAM" figure reflects
    generation only. That is the phase the module docstring's VRAM numbers
    refer to.
    """
    parser = argparse.ArgumentParser(description="DramaBox INT8 optimized inference")
    parser.add_argument("--text", required=True, help="Text to synthesize")
    parser.add_argument("--output", default="output.wav", help="Output WAV path")
    parser.add_argument(
        "--no-offload", action="store_true", help="Disable Gemma CPU offload"
    )
    parser.add_argument(
        "--no-quantize", action="store_true", help="Disable INT8 quantization"
    )
    parser.add_argument(
        "--src-dir",
        default="/app/dramabox/src",
        help="Directory containing the DramaBox sources "
        "(model_downloader, inference_server)",
    )
    args = parser.parse_args()

    # Fail fast with a clear message instead of deep inside model loading.
    if not torch.cuda.is_available():
        sys.exit("error: a CUDA device is required (model is loaded with device='cuda')")

    sys.path.insert(0, args.src_dir)
    from model_downloader import get_all_paths
    from inference_server import TTSServer

    paths = get_all_paths()
    logging.info("Loading DramaBox...")
    tts = TTSServer(
        checkpoint=paths["transformer"],
        full_checkpoint=paths["audio_components"],
        gemma_root=paths["gemma_root"],
        device="cuda",
        dtype="bf16",
        compile_model=False,
        bnb_4bit=True,
    )

    if not args.no_quantize:
        apply_int8_quantization(tts)
    if not args.no_offload:
        apply_gemma_offload(tts)

    # The setup peak includes loading before quantization/offload; report it
    # separately so the final figure measures generation alone.
    setup_peak = torch.cuda.max_memory_allocated(0) / (1024**3)
    logging.info(f"Setup peak VRAM: {setup_peak:.2f} GB")
    torch.cuda.reset_peak_memory_stats(0)

    logging.info(f"Generating: {args.text[:80]}...")
    t0 = time.perf_counter()
    tts.generate(
        prompts=[{"text": args.text}],
        return_type="file",
        output_path=args.output,
    )
    elapsed = time.perf_counter() - t0

    peak_vram = torch.cuda.max_memory_allocated(0) / (1024**3)
    logging.info(
        f"Done in {elapsed:.2f}s | Peak VRAM: {peak_vram:.2f} GB | Saved: {args.output}"
    )


if __name__ == "__main__":
    main()
|
|