File size: 4,402 Bytes
5e19135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Optimized DramaBox inference: INT8 DiT + Gemma CPU offload.

Drops denoising-phase VRAM from 17.4 GB to 5.9 GB, making DramaBox usable
on 16 GB GPUs. Quality is preserved (MCD < 5.0 dB vs BF16 baseline).

Usage:
    python inference_optimized.py --text "Hello world!" --output output.wav

Requires the standard DramaBox installation plus `torchao`.
"""
import argparse
import logging
import re
import sys
import time

import torch

# Configure the root logger at import time: INFO level, with each record
# rendered as "<timestamp> <level> <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


def apply_int8_quantization(tts):
    """Quantize a hand-picked subset of the DiT's Linear layers to INT8.

    Runs torchao's ``quantize_`` with ``Int8WeightOnlyConfig`` twice over
    ``tts._velocity_model``: once for layers inside the transformer and once
    for the two input/output projections. The model is modified in place and
    nothing is returned. The CUDA memory allocated on device 0 is logged
    afterwards.

    Layers picked by the first pass (all must be ``torch.nn.Linear`` and must
    not have "norm" in their qualified name):
      * any layer whose name contains "gate_logits";
      * any layer whose name contains "to_q", "to_k", "to_v" or "to_out";
      * inside "audio_ff" within a numbered ``transformer_blocks.<i>.``:
        every "net.0.proj", and "net.2" only for block index >= 15 except 17.

    NOTE(review): the block-index thresholds (15 and 17) are taken as-is; the
    reason for them is not visible in this file.
    """
    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    attention_markers = ("to_q", "to_k", "to_v", "to_out")
    first_ffn_out_block = 15
    excluded_ffn_out_blocks = {17}
    block_index_re = re.compile(r"transformer_blocks\.(\d+)\.")
    io_layer_names = ("audio_patchify_proj", "audio_proj_out")

    def _selects_dit_layer(module, name):
        # Only Linear layers are candidates, and anything norm-related is kept
        # at its original precision.
        if not isinstance(module, torch.nn.Linear) or "norm" in name:
            return False
        if "gate_logits" in name:
            return True
        for marker in attention_markers:
            if marker in name:
                return True
        if "audio_ff" not in name:
            return False

        # Feed-forward layers are only considered inside a numbered block.
        found = block_index_re.search(name)
        if found is None:
            return False
        block = int(found.group(1))
        if "net.0.proj" in name:
            return True
        return (
            "net.2" in name
            and block >= first_ffn_out_block
            and block not in excluded_ffn_out_blocks
        )

    def _selects_io_layer(module, name):
        # Exact-name match: only the two top-level projections qualify.
        return isinstance(module, torch.nn.Linear) and name in io_layer_names

    logging.info("Applying INT8 quantization to DiT...")
    for selector in (_selects_dit_layer, _selects_io_layer):
        quantize_(tts._velocity_model, Int8WeightOnlyConfig(), filter_fn=selector)

    bytes_per_gib = 1024**3
    vram = torch.cuda.memory_allocated(0) / bytes_per_gib
    logging.info(f"INT8 applied. VRAM: {vram:.2f} GB")


def apply_gemma_offload(tts):
    """Patch the prompt encoder so the Gemma text encoder is offloaded to CPU.

    ``__call__`` is replaced on the *class* of ``tts._prompt_encoder``, so
    every instance of that class is affected, not only this one.

    Behaviour of the patched call:
      * If the encoder is not "warm" (``self._warm`` is falsy or
        ``self._warm_text_encoder`` is ``None``), the original ``__call__``
        runs unchanged with the same arguments.
      * Otherwise the text encoder is moved to ``tts.device`` if its first
        parameter is on the CPU, each prompt is encoded, the encoder is moved
        back to the CPU and the CUDA cache is emptied. The raw outputs are
        then passed through
        ``self._warm_embeddings_processor.process_hidden_states``.
        Extra keyword arguments are accepted but not used on this path.

    The move back to the CPU happens in a ``finally`` clause, so a failing
    ``encode`` call no longer leaves the encoder resident on the GPU. The
    exception still propagates to the caller.

    Args:
        tts: Object exposing ``_prompt_encoder`` and ``device``.
    """
    pe = tts._prompt_encoder
    pe_cls = type(pe)
    orig_call = pe_cls.__call__

    def _offload_call(self, prompts, **kwargs):
        # Cold path: defer entirely to the unpatched implementation.
        if not (self._warm and self._warm_text_encoder is not None):
            return orig_call(self, prompts, **kwargs)

        te = self._warm_text_encoder
        is_on_cpu = next(te.parameters()).device.type == "cpu"

        if is_on_cpu:
            logging.info("Moving Gemma to GPU...")
            te.to(tts.device)

        try:
            raw_outputs = [te.encode(p) for p in prompts]
        finally:
            # Always release the GPU copy, even when encoding raised;
            # otherwise the memory saving this patch exists for is lost.
            logging.info("Offloading Gemma to CPU...")
            te.to("cpu")
            torch.cuda.empty_cache()

        # Post-processing runs after the offload, as before.
        ep = self._warm_embeddings_processor
        return [ep.process_hidden_states(hs, mask) for hs, mask in raw_outputs]

    pe_cls.__call__ = _offload_call
    logging.info("Gemma CPU offload enabled")


def main(argv=None):
    """Command-line entry point: load DramaBox, optimize it, synthesize text.

    Steps:
      1. Parse the command line (``--text`` is required; ``--output``
         defaults to "output.wav"; ``--no-offload`` and ``--no-quantize``
         switch off the respective optimization).
      2. Put "/app/dramabox/src" at the front of ``sys.path`` and import the
         project modules from there.
      3. Build a ``TTSServer`` from the paths returned by ``get_all_paths()``.
      4. Apply INT8 quantization and the Gemma CPU offload unless disabled.
      5. Call ``tts.generate`` for the single prompt and log the elapsed time
         and ``torch.cuda.max_memory_allocated(0)`` in GiB.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which keeps the original
            zero-argument call working.
    """
    parser = argparse.ArgumentParser(description="DramaBox INT8 optimized inference")
    parser.add_argument("--text", required=True, help="Text to synthesize")
    parser.add_argument("--output", default="output.wav", help="Output WAV path")
    parser.add_argument(
        "--no-offload", action="store_true", help="Disable Gemma CPU offload"
    )
    parser.add_argument(
        "--no-quantize", action="store_true", help="Disable INT8 quantization"
    )
    args = parser.parse_args(argv)

    # Project modules are imported only after argument parsing, so --help and
    # usage errors do not require the DramaBox sources to be present.
    sys.path.insert(0, "/app/dramabox/src")
    from model_downloader import get_all_paths
    from inference_server import TTSServer

    paths = get_all_paths()
    logging.info("Loading DramaBox...")
    tts = TTSServer(
        checkpoint=paths["transformer"],
        full_checkpoint=paths["audio_components"],
        gemma_root=paths["gemma_root"],
        device="cuda",
        dtype="bf16",
        compile_model=False,
        bnb_4bit=True,
    )

    if not args.no_quantize:
        apply_int8_quantization(tts)

    if not args.no_offload:
        apply_gemma_offload(tts)

    logging.info(f"Generating: {args.text[:80]}...")
    t0 = time.time()
    # The return value is not needed: the audio is written to args.output.
    tts.generate(
        prompts=[{"text": args.text}],
        return_type="file",
        output_path=args.output,
    )
    elapsed = time.time() - t0

    # The peak counter is not reset before generation, so this figure covers
    # everything allocated on device 0 since the process started.
    peak_vram = torch.cuda.max_memory_allocated(0) / (1024**3)
    logging.info(
        f"Done in {elapsed:.2f}s | Peak VRAM: {peak_vram:.2f} GB | Saved: {args.output}"
    )


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()