---
license: other
license_name: ltx-2-community
license_link: https://huggingface.co/ResembleAI/Dramabox/blob/main/LICENSE
base_model: ResembleAI/Dramabox
tags:
- tts
- text-to-speech
- audio
- quantized
- int8
- dramabox
- torchao
- diffusion-transformer
- flow-matching
library_name: pytorch
pipeline_tag: text-to-speech
---
# DramaBox DiT INT8: Selective Weight-Only Quantization
A selectively quantized version of the [DramaBox TTS](https://huggingface.co/ResembleAI/Dramabox) 3.3B DiT (Diffusion Transformer) model from [Resemble AI](https://huggingface.co/ResembleAI). It reduces peak VRAM by about 20% and checkpoint size by about 45% while keeping audio close to the BF16 baseline (MCD 4.98 dB, under the 5.0 dB acceptance threshold).
> **Base model:** [ResembleAI/Dramabox](https://huggingface.co/ResembleAI/Dramabox) | **Code:** [resemble-ai/DramaBox](https://github.com/resemble-ai/DramaBox) | **Architecture:** LTX-2.3 DiT + Gemma 3 12B
## What's included
| File | Size | Description |
|------|------|-------------|
| `dramabox-dit-int8-selective.safetensors` | 3.37 GB | Quantized DiT weights (INT8 data + BF16 scales) |
| `config.json` | 28 KB | Layer map: which 562 layers are quantized |
| `load_int8.py` | 3.6 KB | Loader script (works with or without torchao) |
| `inference_optimized.py` | 4.3 KB | Full pipeline with INT8 + Gemma CPU offload |
You still need the other components from [ResembleAI/Dramabox](https://huggingface.co/ResembleAI/Dramabox):
- `dramabox-audio-components.safetensors` (1.9 GB): VAE + vocoder
- [unsloth/gemma-3-12b-it-bnb-4bit](https://huggingface.co/unsloth/gemma-3-12b-it-bnb-4bit) (~8 GB): text encoder
## Results
| Metric | Baseline (BF16) | This model (INT8) | Change |
|--------|-----------------|-------------------|--------|
| DiT checkpoint size | 6.1 GB | 3.37 GB | **-45%** |
| Peak VRAM | 17.39 GB | 13.8 GB | **-20.6%** |
| VRAM during denoising | 17.39 GB | 5.93 GB (with Gemma CPU offload) | **-65.9%** |
| Audio quality (MCD vs. BF16) | 0.0 dB (reference) | 4.98 dB | Below 5.0 dB threshold |
| Generation time | 2.62 s | 3.22 s | +23% |
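The Change column is the relative difference from the BF16 baseline. A quick check of that arithmetic, using only the numbers in the table (some entries above are rounded to whole percent):

```python
def pct_change(baseline: float, new: float) -> float:
    """Relative change from the BF16 baseline, in percent."""
    return (new - baseline) / baseline * 100.0


# (BF16 baseline, INT8) values copied from the Results table
rows = {
    "DiT checkpoint size (GB)": (6.1, 3.37),
    "Peak VRAM (GB)": (17.39, 13.8),
    "VRAM during denoising (GB)": (17.39, 5.93),
    "Generation time (s)": (2.62, 3.22),
}

for name, (base, new) in rows.items():
    print(f"{name}: {pct_change(base, new):+.1f}%")
# DiT checkpoint size (GB): -44.8%
# Peak VRAM (GB): -20.6%
# VRAM during denoising (GB): -65.9%
# Generation time (s): +22.9%
```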
MCD (Mel-Cepstral Distortion) measures the spectral distance from the BF16 baseline output; lower is better. Scores below 5.0 dB are treated here as perceptually near-identical for speech, and 5.0 dB is the acceptance threshold used throughout the experiments.
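For reference, the usual frame-level MCD formula is `(10 / ln 10) * sqrt(2 * sum_d (c_d - c'_d)^2)`, averaged over frames, typically skipping the energy coefficient `c_0`. The sketch below assumes both signals have already been converted to mel-cepstral coefficients and aligned frame by frame; how the experiments here extracted coefficients and aligned frames (for example with DTW) is not documented, so treat this as illustrative rather than the exact metric used.

```python
import numpy as np


def mel_cepstral_distortion(ref: np.ndarray, test: np.ndarray, skip_c0: bool = True) -> float:
    """Frame-averaged MCD in dB between two time-aligned (frames, coeffs) arrays."""
    if ref.shape != test.shape:
        raise ValueError("inputs must be time-aligned and have the same shape")
    start = 1 if skip_c0 else 0  # c0 mostly encodes energy, so it is usually excluded
    diff = ref[:, start:] - test[:, start:]
    per_frame = (10.0 / np.log(10.0)) * np.sqrt(2.0 * np.sum(diff ** 2, axis=1))
    return float(np.mean(per_frame))
```

With this definition, a constant difference of 1.0 in a single coefficient gives about 6.14 dB, which is already above the 5.0 dB threshold used here.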
## Quantization details
**Method:** Selective INT8 weight-only quantization via [torchao](https://github.com/pytorch/ao) `Int8WeightOnlyConfig`. Weights are stored as INT8 with per-channel BF16 scales and dequantized at runtime during matrix multiplication.
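A minimal NumPy sketch of the idea behind per-channel weight-only INT8: each output row gets one scale, weights are rounded to INT8, and the matmul runs on a dequantized copy. This is not torchao's implementation (its kernels handle dequantization internally and may use a slightly different integer range and rounding), and float32 stands in for BF16 because NumPy has no bfloat16 type.

```python
import numpy as np


def quantize_int8_per_channel(w: np.ndarray):
    """Symmetric per-output-channel INT8 quantization of an (out_features, in_features) weight."""
    max_abs = np.abs(w).max(axis=1, keepdims=True)
    # One scale per output row; all-zero rows get scale 1.0 to avoid dividing by zero
    scale = np.where(max_abs > 0, max_abs / 127.0, 1.0).astype(np.float32)
    q = np.clip(np.round(w / scale), -128, 127).astype(np.int8)
    return q, scale


def int8_weight_only_linear(x: np.ndarray, q: np.ndarray, scale: np.ndarray, bias=None):
    """Dequantize the stored INT8 weight, then do the matmul in floating point."""
    w = q.astype(np.float32) * scale
    y = x @ w.T
    return y if bias is None else y + bias
```

Only the weights are quantized; activations stay in floating point, which is why this scheme saves memory but does not speed up the matmul and can add dequantization overhead (consistent with the slower generation time above).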
**What's quantized (562 layers, ~81.5% of DiT parameters):**
- All attention projections (`to_q`, `to_k`, `to_v`, `to_out`) across all 48 transformer blocks
- All `gate_logits` layers
- All FFN GELU projections (`audio_ff.net.0.proj`) across all 48 blocks
- FFN output projections (`audio_ff.net.2`) in blocks 15–47, excluding block 17
- Input/output projections (`audio_patchify_proj`, `audio_proj_out`)
**What's NOT quantized (kept in BF16):**
- All normalization layers: extremely sensitive to precision changes
- AdaLN conditioning layers: control the diffusion process globally
- Timestep embedder: conditioning pathway, highly sensitive
- FFN output projections in blocks 0–14: early blocks are the most sensitive to quantization
- FFN output projection in block 17: an anomalously sensitive individual block
This layer map was discovered through 80+ automated experiments using [Andrej Karpathy's auto-research methodology](https://github.com/karpathy/autoresearch), systematically testing each layer type and block index.
## Usage
### Option 1: Runtime quantization (simplest, no extra downloads)
If you just want VRAM savings without downloading this checkpoint, you can apply quantization at load time to the original DramaBox model:
```python
import re

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

# `tts` is the already-loaded standard DramaBox TTSServer.
ATTN_PROJ_KEYS = ("to_q", "to_k", "to_v", "to_out")
BLOCK_RE = re.compile(r"transformer_blocks\.(\d+)\.")


def dit_filter(mod: torch.nn.Module, fqn: str) -> bool:
    """Select the DiT Linear layers that tolerate INT8 weights."""
    if not isinstance(mod, torch.nn.Linear):
        return False
    if "norm" in fqn:  # normalization layers stay in BF16
        return False
    if "gate_logits" in fqn:
        return True
    if any(k in fqn for k in ATTN_PROJ_KEYS):  # attention projections, all blocks
        return True
    if "audio_ff" in fqn:
        m = BLOCK_RE.search(fqn)
        if m:
            idx = int(m.group(1))
            # FFN output projection: blocks 15-47, except block 17
            if "net.2" in fqn and idx >= 15 and idx != 17:
                return True
            if "net.0.proj" in fqn:  # FFN GELU projection, all blocks
                return True
    return False


def io_filter(mod: torch.nn.Module, fqn: str) -> bool:
    """Select the input/output projections."""
    return fqn in ("audio_patchify_proj", "audio_proj_out") and isinstance(mod, torch.nn.Linear)


quantize_(tts._velocity_model, Int8WeightOnlyConfig(), filter_fn=dit_filter)
quantize_(tts._velocity_model, Int8WeightOnlyConfig(), filter_fn=io_filter)
```
### Option 2: Load pre-quantized weights (faster startup)
```python
# `tts` is the already-loaded DramaBox TTSServer; load_int8.py ships with this repository.
from load_int8 import load_int8_dit

# Load the INT8 safetensors and reconstruct the quantized Linear layers
load_int8_dit(tts._velocity_model, "dramabox-dit-int8-selective.safetensors")
```
### Option 3: Full optimized pipeline with Gemma offload
For maximum VRAM savings (5.93 GB during denoising), use the included `inference_optimized.py`, which also offloads Gemma 3 12B to the CPU between text encoding and audio generation.
## Requirements
- PyTorch >= 2.4
- torchao >= 0.15.0
- CUDA GPU with >= 16 GB VRAM (14 GB with Gemma offload)
- The original DramaBox model and its dependencies
## How this was made
We ran 80+ experiments using an automated loop inspired by Karpathy's auto-research methodology:
1. Start from the BF16 baseline
2. Modify quantization config (which layers, which precision, which blocks)
3. Generate 3 evaluation audio samples with fixed prompts/seeds
4. Measure peak VRAM, generation time, and MCD vs baseline
5. Keep the change if MCD < 5.0 dB, discard otherwise
6. Repeat
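The keep/discard loop above amounts to a greedy search over quantization settings. A minimal sketch, assuming a hypothetical `evaluate(config)` callable that generates the three fixed-seed samples and returns their mean MCD against the baseline; the real harness, config format and layer-group names are not part of this repository, and the group names below are made up for illustration.

```python
from typing import Callable, Dict, Iterable, List, Tuple

MCD_THRESHOLD_DB = 5.0  # acceptance threshold from step 5


def greedy_search(
    baseline: Dict[str, bool],
    candidates: Iterable[Tuple[str, bool]],
    evaluate: Callable[[Dict[str, bool]], float],
) -> Tuple[Dict[str, bool], List[Tuple[str, float, bool]]]:
    """Try one config change at a time; keep it only if MCD stays under the threshold."""
    config = dict(baseline)  # step 1: start from the BF16 baseline (nothing quantized)
    log: List[Tuple[str, float, bool]] = []
    for layer_group, quantize in candidates:
        trial = {**config, layer_group: quantize}  # step 2: modify the config
        mcd = evaluate(trial)  # steps 3-4: generate samples, measure MCD (hypothetical)
        kept = mcd < MCD_THRESHOLD_DB  # step 5: keep or discard
        if kept:
            config = trial
        log.append((layer_group, mcd, kept))
    return config, log


if __name__ == "__main__":
    # Toy evaluator: each enabled (hypothetical) group adds a fixed MCD penalty
    penalty = {"attn": 1.0, "gate_logits": 0.5, "ffn_out_late": 2.0, "ffn_out_early": 3.0}
    toy_eval = lambda cfg: sum(penalty[k] for k, on in cfg.items() if on)
    final, history = greedy_search(
        {k: False for k in penalty}, [(k, True) for k in penalty], toy_eval
    )
    print(final)
    print(history)
```

Because each change is accepted or rejected against the current config, the outcome depends on the order in which candidates are tried; the experiments reported here tested layer types and block indices systematically, which this sketch does not reproduce.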
Key findings from the search:
- **Flow-matching diffusion models are far more precision-sensitive than autoregressive LLMs.** All 4-bit approaches (NF4, NVFP4, FP4, Int4) produced unacceptable quality (MCD 17–32 dB).
- **FP8 is worse than INT8** for weight representation in this model (MCD 11.8 dB vs. 4.35 dB).
- **`torch.compile` breaks audio output** even on the unquantized baseline (MCD 24–32 dB). The iterative denoising loop is numerically sensitive to graph optimizations.
- **Early transformer blocks (0–14) are most sensitive** in their FFN output projections. Block 17 is an outlier.
- **Attention projections and GELU gates are universally robust** to INT8 across all 48 blocks.
## Citation
If you use this work, please cite the original DramaBox model:
```bibtex
@misc{dramabox2025,
title={DramaBox: Expressive Text to Speech Model},
author={Resemble AI},
year={2025},
url={https://github.com/resemble-ai/DramaBox}
}
```
## License
Same as the base DramaBox model: [LTX-2 Community License](https://huggingface.co/ResembleAI/Dramabox/blob/main/LICENSE).