| from dataclasses import asdict, dataclass |
| from typing import Dict, Optional, List |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
|
|
# Module-level logger, following the transformers convention of one
# logger per module keyed on __name__.
logger = logging.get_logger(__name__)
|
|
|
|
@dataclass
class GPTAudioConfig:
    """Configuration for GPT audio processing parameters"""
    mel_channels: int = 80  # number of mel-spectrogram channels
    sample_rate: int = 22050  # processing sample rate in Hz
    output_sample_rate: int = 24000  # sample rate of produced audio in Hz
|
|
@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters"""
    sample_rate: int = 22050  # input/processing sample rate in Hz
    output_sample_rate: int = 24000  # sample rate of produced audio in Hz
    mel_channels: int = 80  # number of mel-spectrogram channels
    hop_length: int = 256  # STFT hop length in samples
    win_length: int = 1024  # STFT window length in samples
    n_fft: int = 1024  # FFT size
    fmin: int = 0  # lower mel frequency bound in Hz (per field name)
    fmax: int = 8000  # upper mel frequency bound in Hz (per field name)
    power: float = 1.0  # spectrogram magnitude exponent
    mel_norms_file: Optional[str] = None  # optional path to mel normalization stats
|
|
|
|
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Groups the transformer hyper-parameters, the text/audio token-space
    layout, sequence-length limits, behaviour flags and a nested
    ``GPTAudioConfig`` describing the audio front-end.
    """
    model_type = "xtts_gpt"

    def __init__(
        self,
        # --- Transformer architecture ---
        hidden_size: int = 1024,
        n_inner: int = 4096,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 16,
        # --- Text token space ---
        vocab_size: int = 6681,
        number_text_tokens: int = 6681,
        start_text_token: Optional[int] = None,
        stop_text_token: Optional[int] = None,
        # --- Audio token space ---
        num_audio_tokens: int = 1026,
        start_audio_token: int = 1024,
        stop_audio_token: int = 1025,
        # --- Sequence-length limits ---
        max_audio_tokens: int = 605,
        max_text_tokens: int = 402,
        max_prompt_tokens: int = 70,
        gpt_max_audio_tokens: int = 605,
        # --- Behaviour flags ---
        use_masking_gt_prompt_approach: bool = True,
        use_perceiver_resampler: bool = True,
        kv_cache: bool = True,
        enable_redaction: bool = False,
        # --- Batching ---
        gpt_batch_size: int = 1,
        # --- Nested audio settings (dict form of GPTAudioConfig) ---
        audio_config: Optional[Dict] = None,
        # --- Attention / initialisation details ---
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        add_cross_attention: bool = False,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
        # --- Decoder ---
        decoder_input_dim: int = 1024,
        # None sentinels avoid the shared-mutable-default-argument pitfall;
        # the effective defaults (unchanged) are filled in below.
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        activation_function: str = "gelu",
        attn_pdrop: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.architectures = ["XttsGPT"] if architectures is None else architectures
        self.auto_map = auto_map if auto_map is not None else {
            "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
            "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
        }
        # Parentheses make explicit that the conditional feeds the ** unpack
        # (the original relied on the conditional expression binding tighter,
        # which parses the same but reads ambiguously).
        self.audio_config = GPTAudioConfig(
            **(audio_config if audio_config is not None else {})
        )
        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop
        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Text token space
        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token

        # Audio token space
        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token

        # Sequence-length limits
        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens

        # Behaviour flags
        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction

        self.gpt_batch_size = gpt_batch_size

        # Attention / initialisation details
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.decoder_input_dim = decoder_input_dim

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        Overrides the base serialization so the ``GPTAudioConfig``
        dataclass is flattened into a JSON-safe plain dict.
        """
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary.

        NOTE(review): ``*args``/``**kwargs`` are accepted for signature
        compatibility with ``PretrainedConfig.from_dict`` but intentionally
        ignored, as in the original implementation.
        """
        return cls(**config_dict)
|
|
|
|
class XTTSConfig(PretrainedConfig):
    """Configuration class for XTTS model components except GPT.

    Holds the audio front-end settings, decoder / speaker-embedding
    dimensions, tokenizer info, supported languages, and a nested
    ``XTTSGPTConfig`` (exposed as ``self.gpt``).
    """
    model_type = "xtts"

    def __init__(
        self,
        # --- Audio (dict form of XTTSAudioConfig) ---
        audio_config: Optional[Dict] = None,
        input_sample_rate: int = 22050,
        output_sample_rate: int = 24000,
        output_hop_length: int = 256,
        # --- Decoder / speaker embedding ---
        decoder_input_dim: int = 1024,
        d_vector_dim: int = 512,
        cond_d_vector_in_each_upsampling_layer: bool = True,
        # --- GPT code timing ---
        gpt_code_stride_len: int = 1024,
        duration_const: int = 102400,
        # --- Tokenizer ---
        tokenizer_file: str = "",
        num_chars: int = 255,
        # --- Supported languages ---
        languages: Optional[List[str]] = None,
        # --- Nested GPT config (dict form of XTTSGPTConfig) ---
        gpt_config: Optional[Dict] = None,
        # None sentinels avoid the shared-mutable-default-argument pitfall;
        # the effective defaults (unchanged) are filled in below.
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.architectures = ["Xtts"] if architectures is None else architectures
        self.auto_map = auto_map if auto_map is not None else {
            "AutoConfig": "AstraMindAI/xtts2--xtts2_config.XTTSConfig",
            "AutoModelForCausalLM": "AstraMindAI/xtts2--xtts2_modeling.Xtts",
        }
        # Parentheses make explicit that the conditional feeds the ** unpack.
        self.audio_config = XTTSAudioConfig(
            **(audio_config if audio_config is not None else {})
        )

        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length

        # Decoder / speaker embedding
        self.decoder_input_dim = decoder_input_dim
        self.d_vector_dim = d_vector_dim
        self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer

        # GPT code timing
        self.gpt_code_stride_len = gpt_code_stride_len
        self.duration_const = duration_const

        # Tokenizer
        self.tokenizer_file = tokenizer_file
        self.num_chars = num_chars

        # Nested GPT configuration (stored as a full XTTSGPTConfig object).
        self.gpt = XTTSGPTConfig(
            **(gpt_config if gpt_config is not None else {})
        )

        if languages is None:
            self.languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru",
                "nl", "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]
        else:
            self.languages = languages

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        Flattens the audio dataclass and the nested GPT config into
        JSON-safe plain dicts under ``audio_config`` / ``gpt_config``.
        NOTE(review): ``super().to_dict()`` may also leave the raw ``gpt``
        attribute in the output; confirm whether it should be dropped
        before JSON serialization.
        """
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        output["gpt_config"] = self.gpt.to_dict()
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSConfig":
        """Create a config from a dictionary.

        ``gpt_config`` needs no special casing: ``__init__`` already accepts
        it as a keyword, so passing the dict straight through is equivalent
        to the previous extract-and-forward branch.
        """
        return cls(**config_dict)