"""Optimized DramaBox inference: INT8 DiT + Gemma CPU offload.

Drops denoising-phase VRAM from 17.4 GB to 5.9 GB, making DramaBox usable on
16 GB GPUs. Quality is preserved (MCD < 5.0 dB vs BF16 baseline).

Usage:
    python inference_optimized.py --text "Hello world!" --output output.wav

Requires the standard DramaBox installation plus `torchao`.
"""

import argparse
import logging
import re
import sys
import time

import torch

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


def apply_int8_quantization(tts):
    """Apply selective INT8 weight-only quantization to the DiT.

    Two passes of ``torchao.quantization.quantize_`` are run in place over
    ``tts._velocity_model``: one selecting linear layers inside the
    transformer (see ``dit_filter``) and one selecting the two named
    input/output projections (see ``io_filter``).

    Args:
        tts: Server object exposing the DiT as ``_velocity_model``.

    Returns:
        None. The model is modified in place and the resulting allocated
        VRAM on CUDA device 0 is logged.
    """
    # Imported lazily so the script can still be imported without torchao.
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    attn_proj_keys = ("to_q", "to_k", "to_v", "to_out")
    # NOTE(review): the block thresholds below are presumably tuned
    # empirically to keep quality within budget -- confirm before changing.
    FFN_START = 15
    FFN_SKIP = {17}
    # Compiled once; the filter runs for every submodule of the model.
    block_index_re = re.compile(r"transformer_blocks\.(\d+)\.")

    def dit_filter(mod, fqn):
        """Return True if the module named ``fqn`` should be quantized."""
        if not isinstance(mod, torch.nn.Linear):
            return False
        if "norm" in fqn:
            return False
        if "gate_logits" in fqn:
            return True
        if any(key in fqn for key in attn_proj_keys):
            return True
        if "audio_ff" not in fqn:
            return False

        # Feed-forward layers are only considered when the name carries a
        # transformer block index.
        match = block_index_re.search(fqn)
        if match is None:
            return False
        if "net.0.proj" in fqn:
            return True
        idx = int(match.group(1))
        # The FFN output projection is quantized only from FFN_START onwards,
        # excluding the blocks listed in FFN_SKIP.
        return "net.2" in fqn and idx >= FFN_START and idx not in FFN_SKIP

    def io_filter(mod, fqn):
        """Return True for the top-level audio input/output projections."""
        return fqn in ("audio_patchify_proj", "audio_proj_out") and isinstance(
            mod, torch.nn.Linear
        )

    logging.info("Applying INT8 quantization to DiT...")
    quantize_(tts._velocity_model, Int8WeightOnlyConfig(), filter_fn=dit_filter)
    quantize_(tts._velocity_model, Int8WeightOnlyConfig(), filter_fn=io_filter)

    vram = torch.cuda.memory_allocated(0) / (1024**3)
    logging.info(f"INT8 applied. VRAM: {vram:.2f} GB")


def apply_gemma_offload(tts):
    """Patch PromptEncoder to offload Gemma 12B to CPU between uses.

    The ``__call__`` of the prompt encoder's *class* is replaced, so every
    instance of that class is affected. When the encoder is not warm (or has
    no warm text encoder) the original ``__call__`` is used unchanged.
    Otherwise the text encoder is moved to ``tts.device`` if it currently
    lives on the CPU, each prompt is encoded, and the encoder is moved back
    to the CPU before the hidden states are post-processed.

    Args:
        tts: Server object exposing ``_prompt_encoder`` and ``device``.

    Returns:
        None. The prompt encoder class is patched in place.
    """
    pe = tts._prompt_encoder
    pe_cls = type(pe)
    orig_call = pe_cls.__call__

    def _offload_call(self, prompts, **kwargs):
        if not (self._warm and self._warm_text_encoder is not None):
            return orig_call(self, prompts, **kwargs)

        # NOTE(review): ``kwargs`` are not forwarded on this path, matching
        # the behavior this patch has always had -- confirm that is intended.
        te = self._warm_text_encoder
        is_on_cpu = next(te.parameters()).device.type == "cpu"

        try:
            if is_on_cpu:
                logging.info("Moving Gemma to GPU...")
                te.to(tts.device)
            raw_outputs = [te.encode(p) for p in prompts]
        finally:
            # Always release the GPU copy, even when the transfer or the
            # encoding fails (e.g. CUDA OOM); otherwise the encoder would stay
            # resident on the GPU and defeat the purpose of the offload.
            logging.info("Offloading Gemma to CPU...")
            te.to("cpu")
            torch.cuda.empty_cache()

        ep = self._warm_embeddings_processor
        return [ep.process_hidden_states(hs, mask) for hs, mask in raw_outputs]

    pe_cls.__call__ = _offload_call
    logging.info("Gemma CPU offload enabled")


def main():
    """Parse CLI arguments, load DramaBox, apply optimizations and synthesize.

    Quantization and offload are both enabled unless ``--no-quantize`` /
    ``--no-offload`` are given. The elapsed generation time and the peak
    allocated VRAM on CUDA device 0 are logged at the end.
    """
    parser = argparse.ArgumentParser(description="DramaBox INT8 optimized inference")
    parser.add_argument("--text", required=True, help="Text to synthesize")
    parser.add_argument("--output", default="output.wav", help="Output WAV path")
    parser.add_argument(
        "--no-offload", action="store_true", help="Disable Gemma CPU offload"
    )
    parser.add_argument(
        "--no-quantize", action="store_true", help="Disable INT8 quantization"
    )
    args = parser.parse_args()

    # The DramaBox sources are expected at this fixed location; the project
    # modules are imported only after the path has been extended.
    sys.path.insert(0, "/app/dramabox/src")
    from model_downloader import get_all_paths
    from inference_server import TTSServer

    paths = get_all_paths()

    logging.info("Loading DramaBox...")
    tts = TTSServer(
        checkpoint=paths["transformer"],
        full_checkpoint=paths["audio_components"],
        gemma_root=paths["gemma_root"],
        device="cuda",
        dtype="bf16",
        compile_model=False,
        bnb_4bit=True,
    )

    if not args.no_quantize:
        apply_int8_quantization(tts)
    if not args.no_offload:
        apply_gemma_offload(tts)

    logging.info(f"Generating: {args.text[:80]}...")
    t0 = time.time()
    # The audio is written to ``args.output``; the return value is not needed.
    tts.generate(
        prompts=[{"text": args.text}],
        return_type="file",
        output_path=args.output,
    )
    elapsed = time.time() - t0

    peak_vram = torch.cuda.max_memory_allocated(0) / (1024**3)
    logging.info(
        f"Done in {elapsed:.2f}s | Peak VRAM: {peak_vram:.2f} GB | Saved: {args.output}"
    )


if __name__ == "__main__":
    main()