GibbsTTS / config.py
ydqmkkx's picture
update
0afe769
from dataclasses import dataclass
from typing import Optional, Union

from text import symbols
@dataclass
class ModelConfig:
    """Hyperparameters and file paths for training and inference.

    All fields carry defaults, so ``ModelConfig()`` yields a usable
    configuration and individual values can be overridden by keyword.

    Note:
        ``intermediate_size`` and ``n_heads`` are computed once, when the
        class body is evaluated, from the *default* ``hidden_size``. Passing
        a different ``hidden_size`` to the constructor does NOT recompute
        them; override them explicitly as well.
    """

    # model and train configs
    # One more entry than the symbol table (extra slot presumably reserved
    # for a padding/blank token -- confirm against the embedding code).
    n_vocab: int = len(symbols) + 1
    n_lang: int = 3
    hidden_size: int = 1024
    # Derived from the default hidden_size at class-definition time.
    intermediate_size: int = hidden_size * 4
    # Derived from the default hidden_size at class-definition time.
    n_heads: int = hidden_size // 64
    n_layers: int = 16
    dropout: float = 0.
    cfg_dropout: float = 0.15
    quantizers_num: int = 12
    codebook_size: int = 1024
    special_codebook_size: int = 1
    # The default is a single int; the original annotation was
    # tuple[int, ...], so a per-bin tuple may also be accepted by the
    # training code -- NOTE(review): confirm against the data loader.
    batch_size: Union[int, tuple[int, ...]] = 16
    length_bins: tuple[int, ...] = (
        100, 256, 384, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536,
    )
    t_grid_size: int = 1024
    batch_accum: int = 1
    learning_rate: float = 2e-4
    min_lr_ratio: float = 0.1
    # Optional: None is an allowed value (presumably disables clipping).
    grad_clip_thresh: Optional[float] = 1.0
    num_epochs: int = 10
    num_shards: int = 3
    # Fractional values are allowed (default 1.5), hence float.
    warmup_shards: float = 1.5
    log_interval_step: int = 10
    save_interval_shard: int = 1
    num_workers: int = 8
    # Boolean flag; defaults to False.
    train: bool = False
    codebook_weights_path: str = './pretrained/codebook_embeds.pt'
    model_save_path: str = './checkpoints/tmp'
    log_dir: str = './runs/tmp'
    # None by default (presumably meaning no checkpoint is loaded).
    load_ckpt_path: Optional[str] = None
    dataset_path: str = ''

    # infer configs
    infer_ckpt_dir: str = './pretrained'
    infer_ckpt_path: str = './pretrained/GibbsTTS_large_ema.safetensors'
    steps: int = 32
    rescale_cfg: float = 0.75
    cfg: float = 2.5
    temperature: float = 0.6
    top_p: float = 1.