# Veda TTS 80M

A lightweight 80M-parameter text-to-speech model distilled from a 206M-parameter teacher, generating high-quality speech with the SNAC neural audio codec.

## Model Details

| Property | Value |
|---|---|
| Parameters | 80M |
| Architecture | CGNv2 Transformer (768d, 12 layers, 12 heads, 6 KV heads) |
| Audio Codec | SNAC 24 kHz (3-level, 4096-entry codebook, 84 tokens/sec) |
| Training Data | LJSpeech |
| Training Method | Knowledge distillation (α = 0.7) from a 206M teacher |
| Vocab Size | 12,363 tokens |
| Max Sequence Length | 4,096 |
| Sample Rate | 24,000 Hz |
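The 84 tokens/sec figure follows from SNAC's hierarchical codebooks: the 24 kHz model emits three token streams whose frame rates double level to level (roughly 12, 24, and 48 frames per second — these per-level rates are an assumption here, chosen because they sum to the stated 84):

```python
# Assumed per-level frame rates for SNAC 24 kHz (each level doubles the previous)
level_rates = [12, 24, 48]          # tokens per second of audio, per level
tokens_per_sec = sum(level_rates)   # 84, matching the model card

def token_budget(audio_seconds: float) -> int:
    """Total SNAC tokens needed to represent `audio_seconds` of audio."""
    return int(audio_seconds * tokens_per_sec)

print(tokens_per_sec)     # 84
print(token_budget(5.0))  # 420 tokens for 5 seconds of speech
```

With a 4,096-token context, this rate bounds a single generation to roughly 4096 / 84 ≈ 48 seconds of audio, minus whatever the phoneme and prosodic-plan tokens consume.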

## Architecture

Veda TTS uses a Chain-of-Generation (CGNv2) architecture with a prosodic chain-of-thought:

1. **Text → Phonemes**: Flite G2P converts the input text to phonemes
2. **Prosodic Planning**: the model generates a prosodic plan (duration, pitch, emphasis, pauses)
3. **Audio Generation**: the model generates SNAC audio tokens autoregressively
4. **SNAC Decoding**: the SNAC codec decodes the tokens to a 24 kHz waveform

The model uses:

- Grouped Query Attention (GQA) with 12 query heads and 6 KV heads
- RoPE positional encoding
- SwiGLU feed-forward networks (d_ff = 3072)
- Classifier-free guidance support (cfg_drop_prob = 0.1)
- A weight-untied output projection
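The GQA layout above (12 query heads sharing 6 KV heads) can be sketched minimally in PyTorch. This is an illustrative implementation, not the model's actual code:

```python
import torch
import torch.nn.functional as F

d_model, n_heads, n_kv_heads = 768, 12, 6
head_dim = d_model // n_heads   # 64
group = n_heads // n_kv_heads   # 2 query heads share each KV head

q_proj = torch.nn.Linear(d_model, n_heads * head_dim, bias=False)
k_proj = torch.nn.Linear(d_model, n_kv_heads * head_dim, bias=False)  # half the K/V params of MHA
v_proj = torch.nn.Linear(d_model, n_kv_heads * head_dim, bias=False)

x = torch.randn(1, 10, d_model)  # (batch, seq, d_model)
q = q_proj(x).view(1, 10, n_heads, head_dim).transpose(1, 2)
k = k_proj(x).view(1, 10, n_kv_heads, head_dim).transpose(1, 2)
v = v_proj(x).view(1, 10, n_kv_heads, head_dim).transpose(1, 2)

# Each KV head serves `group` query heads: repeat K/V along the head axis
k = k.repeat_interleave(group, dim=1)
v = v.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 12, 10, 64])
```

Besides cutting K/V projection parameters in half, GQA also halves the KV cache, which matters for the autoregressive decode path shown in the ONNX section.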

## Performance

| Metric | Value |
|---|---|
| WER | 0.247 |
| DNSMOS | 3.31 |
| RTF (GPU) | 1.78 |
| RTF (ONNX INT8, CPU) | 1.11 |

## ONNX Benchmark (CPU)

| Method | Tokens/sec | Avg Time (512 tokens) |
|---|---|---|
| PyTorch CPU (KV cache) | 43.5 tok/s | 10.9 s |
| ONNX FP32 CPU | 42.8 tok/s | 9.5 s |
| ONNX INT8 CPU | 75.9 tok/s | 5.2 s |
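These throughputs appear consistent with the RTF figures above if RTF is read as wall-clock generation time divided by audio duration: at 84 audio tokens per second of speech, the INT8 path gives 84 / 75.9 ≈ 1.11. A quick check of that arithmetic:

```python
tokens_per_audio_sec = 84   # SNAC token rate (from Model Details)
int8_throughput = 75.9      # tok/s, ONNX INT8 CPU (table above)
fp32_throughput = 42.8      # tok/s, ONNX FP32 CPU

# Real-time factor: wall-clock seconds spent per second of audio produced
rtf_int8 = tokens_per_audio_sec / int8_throughput
print(round(rtf_int8, 2))   # 1.11, matching the Performance table

# INT8 quantization roughly halves per-token latency on CPU
print(round(int8_throughput / fp32_throughput, 2))  # 1.77x over FP32
```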

## File Sizes

| File | Size |
|---|---|
| model.safetensors (PyTorch) | 503 MB |
| ONNX FP32 (prefill + decode) | 503 MB each |
| ONNX INT8 (prefill + decode) | 128 MB each |
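The INT8 size is close to what weight-only quantization predicts: 8-bit weights take a quarter of the FP32 footprint, with the small remainder plausibly accounted for by quantization scales and tensors left in FP32 (an assumption, not a statement about the export):

```python
fp32_size_mb = 503
expected_int8_mb = fp32_size_mb / 4   # 8-bit weights = 1/4 of 32-bit
print(round(expected_int8_mb))        # ~126 MB vs. the observed 128 MB
```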

## Files

```
├── model.safetensors          # PyTorch weights (safetensors format)
├── config.json                # Model configuration
├── tokenizer.json             # Phoneme tokenizer (41 ARPAbet phonemes)
├── benchmark_results.json     # Detailed benchmark data
└── onnx/
    ├── cgn_v2_prefill.onnx        # ONNX FP32 prefill model
    ├── cgn_v2_decode.onnx         # ONNX FP32 decode model (with KV cache)
    ├── cgn_v2_prefill_int8.onnx   # ONNX INT8 quantized prefill
    └── cgn_v2_decode_int8.onnx    # ONNX INT8 quantized decode
```

## Usage

### PyTorch Inference

```python
import torch
import torchaudio
from safetensors.torch import load_file
from vedatts.models.cgn_v2.config import CGNv2Config
from vedatts.models.cgn_v2.model import CGNv2
from vedatts.models.cgn_v2.tokenizer import CGNv2Tokenizer
from vedatts.models.cgn_v2.generate import synthesize, tokens_to_audio
from vedatts.codec.snac import SNACCodec

# Load model
config = CGNv2Config(
    d_model=768, n_layers=12, n_heads=12, n_kv_heads=6,
    d_ff=3072, vocab_size=12363, weight_tying=False, dropout=0.0,
    n_speakers=902,
)
model = CGNv2(config)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict, strict=True)
model.eval()

# Load tokenizer
tokenizer = CGNv2Tokenizer.load("tokenizer.json")

# Load SNAC codec
codec = SNACCodec(device="cpu")

# Synthesize
text = "Hello world, this is Veda TTS."
_, token_ids = synthesize(model, tokenizer, text, device="cpu", temperature=0.8)
audio = tokens_to_audio(token_ids, tokenizer, codec)

# Save
torchaudio.save("output.wav", audio.squeeze(0).cpu(), 24000)
```

### ONNX Inference (CPU)

```python
import numpy as np
import onnxruntime as ort

# Load ONNX sessions
opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL  # required for the INT8 models
opts.intra_op_num_threads = 4

prefill = ort.InferenceSession("onnx/cgn_v2_prefill_int8.onnx", opts, providers=["CPUExecutionProvider"])
decode = ort.InferenceSession("onnx/cgn_v2_decode_int8.onnx", opts, providers=["CPUExecutionProvider"])

# Build the prompt (requires Flite for G2P):
# prompt = [BOS] + phoneme_ids + [PLAN_START]

# Run prefill over the full prompt
outputs = prefill.run(None, {"input_ids": prompt_array})
logits = outputs[0]
kv_caches = outputs[1:]  # one KV cache tensor per layer

# Autoregressive decode loop.
# `sample` and `kv_cache_dict` are placeholders: sample() picks the next
# token from the logits, and kv_cache_dict maps the decode model's cache
# input names to the current per-layer KV tensors.
for step in range(max_tokens):
    token = sample(logits)
    if token == EOS:
        break
    decode_input = {"input_id": [[token]], **kv_cache_dict}
    outputs = decode.run(None, decode_input)
    logits = outputs[0]
    # Update the KV caches from outputs[1:] ...
```
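The `sample` helper in the loop above is left undefined; here is a minimal temperature-sampling sketch in NumPy, assuming `logits` has shape `(1, seq, vocab)` and the last position holds the next-token distribution:

```python
import numpy as np

def sample(logits: np.ndarray, temperature: float = 0.8) -> int:
    """Draw the next token id from the last position's logits."""
    rng = np.random.default_rng()
    z = logits[0, -1, :].astype(np.float64) / temperature
    z -= z.max()                        # numerical stability before exp
    probs = np.exp(z) / np.exp(z).sum()
    return int(rng.choice(len(probs), p=probs))

# Example: a distribution heavily peaked on token 2
demo = np.array([[[0.0, 0.0, 10.0, 0.0]]])
print(sample(demo, temperature=0.8))   # almost always 2
```

A production loop would typically add top-k or top-p filtering before the draw; this sketch shows only the temperature scaling named in the PyTorch example (`temperature=0.8`).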

## Training

This model was trained with knowledge distillation from a 206M-parameter teacher model:

- **Teacher**: CGNv2 Base (1024d, 16 layers, 206M params) trained on LJSpeech
- **Student**: CGNv2 Small (768d, 12 layers, 80M params)
- **Distillation**: α = 0.7 (70% soft labels from the teacher, 30% hard labels)
- **Data**: LJSpeech (~24 hours of single-speaker English audiobook recordings)
- **Checkpoint**: step 35,000
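The α = 0.7 mix can be read as a standard distillation objective: 70% KL divergence against the teacher's output distribution plus 30% cross-entropy against the ground-truth tokens. An illustrative PyTorch version (the exact loss form and any temperature handling are assumptions, not the training code):

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, targets, alpha=0.7, T=1.0):
    """alpha * KL(teacher || student) + (1 - alpha) * CE(student, targets)."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # rescale gradients when softening with T > 1
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1 - alpha) * hard

student = torch.randn(8, 12363)             # (batch, vocab) logits
teacher = torch.randn(8, 12363)
targets = torch.randint(0, 12363, (8,))
loss = distill_loss(student, teacher, targets)
print(loss.item())
```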

## Limitations

- Single-speaker model (LJSpeech voice); speaker embeddings are architecturally supported but multi-speaker synthesis was not trained
- English only (ARPAbet phoneme set)
- Requires Flite for grapheme-to-phoneme conversion
- Requires the SNAC codec (hubertsiuzdak/snac_24khz) for audio decoding

## Citation

```bibtex
@misc{vedatts2025,
  title={Veda TTS: Lightweight Text-to-Speech with Prosodic Chain-of-Thought},
  author={Sai Krishna Rallabandi},
  year={2025},
}
```

## License

Apache 2.0
