| from dataclasses import asdict, dataclass |
| from typing import Dict, Optional, List |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
|
|
# Module-level logger, obtained through transformers' logging utilities.
logger = logging.get_logger(name=__name__)
|
|
|
|
@dataclass()
class GPTAudioConfig:
    """Audio-processing parameters used by the XTTS GPT component.

    Fields (and therefore positional constructor arguments) are, in order:
    ``mel_channels``, ``sample_rate``, ``output_sample_rate``.
    """

    # Number of mel channels.
    mel_channels: int = 80
    # Input / output sample rates (presumably in Hz -- confirm with callers).
    sample_rate: int = 22_050
    output_sample_rate: int = 24_000
|
|
@dataclass()
class XTTSAudioConfig:
    """Audio-processing parameters for XTTS.

    Fields (and therefore positional constructor arguments) are grouped as:
    sample rates, mel-spectrogram/STFT settings, then an optional mel-norms
    file reference.
    """

    # Sample rates (presumably in Hz -- confirm with callers).
    sample_rate: int = 22_050
    output_sample_rate: int = 24_000

    # Mel-spectrogram / STFT settings.
    mel_channels: int = 80
    hop_length: int = 256
    win_length: int = 1_024
    n_fft: int = 1_024
    fmin: int = 0
    fmax: int = 8_000
    power: float = 1.0

    # Optional mel-norms file (presumably a path); None when not provided.
    mel_norms_file: Optional[str] = None
|
|
|
|
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Holds the hyper-parameters of the XTTS GPT model together with a nested
    :class:`GPTAudioConfig`. Every named constructor argument is stored as an
    attribute of the same name; any other keyword arguments are forwarded to
    :class:`transformers.PretrainedConfig`.
    """

    model_type = "xtts_gpt"

    def __init__(
        self,
        # Transformer dimensions
        hidden_size: int = 1024,
        n_inner: int = 4096,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 16,
        # Text tokens
        vocab_size: int = 6681,
        number_text_tokens: int = 6681,
        start_text_token: Optional[int] = None,
        stop_text_token: Optional[int] = None,
        # Audio tokens
        num_audio_tokens: int = 1026,
        start_audio_token: int = 1024,
        stop_audio_token: int = 1025,
        # Sequence-length limits
        max_audio_tokens: int = 605,
        max_text_tokens: int = 402,
        max_prompt_tokens: int = 70,
        gpt_max_audio_tokens: int = 605,
        # Feature flags
        use_masking_gt_prompt_approach: bool = True,
        use_perceiver_resampler: bool = True,
        kv_cache: bool = True,
        enable_redaction: bool = False,
        # Batching
        gpt_batch_size: int = 1,
        # Nested audio parameters (mapping or GPTAudioConfig)
        audio_config: Optional[Dict] = None,
        # Attention / normalization options (names mirror GPT2Config)
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        add_cross_attention: bool = False,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
        # Misc
        decoder_input_dim: int = 1024,
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        activation_function: str = "gelu",
        attn_pdrop: float = 0.1,
        **kwargs,
    ):
        """Build the config.

        Args:
            audio_config: ``None`` (use :class:`GPTAudioConfig` defaults), a
                mapping of ``GPTAudioConfig`` fields, or a ``GPTAudioConfig``
                instance (copied, so the caller's object is not shared).
            architectures: Model class names; ``None`` selects ``["XttsGPT"]``.
            auto_map: Auto-class to remote-code mapping; ``None`` selects the
                ``AstraMindAI/xtts2-gpt`` entries for ``AutoConfig`` and
                ``AutoModelForCausalLM``.
            **kwargs: Forwarded to :meth:`PretrainedConfig.__init__`.

        All other arguments are stored unchanged as same-named attributes.
        """
        super().__init__(**kwargs)

        # Build fresh defaults per instance: a list/dict literal in the
        # signature would be one object shared by every config, so mutating
        # one config's value would silently change all the others.
        self.architectures = ["XttsGPT"] if architectures is None else architectures
        if auto_map is None:
            auto_map = {
                "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
            }
        self.auto_map = auto_map

        self.audio_config = self._coerce_audio_config(audio_config)
        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop

        # Transformer dimensions
        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Text tokens
        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token

        # Audio tokens
        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token

        # Sequence-length limits
        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens

        # Feature flags
        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction

        # Batching
        self.gpt_batch_size = gpt_batch_size

        # Attention / normalization options
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.decoder_input_dim = decoder_input_dim

    @staticmethod
    def _coerce_audio_config(audio_config) -> "GPTAudioConfig":
        """Normalize None / mapping / GPTAudioConfig into a fresh GPTAudioConfig."""
        if audio_config is None:
            return GPTAudioConfig()
        if isinstance(audio_config, GPTAudioConfig):
            # Copy so two configs never share one mutable dataclass instance.
            return GPTAudioConfig(**asdict(audio_config))
        return GPTAudioConfig(**audio_config)

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        Extends :meth:`PretrainedConfig.to_dict` by turning the nested
        :class:`GPTAudioConfig` dataclass into a plain dict.

        Returns:
            Dict: The serialized configuration.
        """
        output = super().to_dict()
        audio_config = getattr(self, "audio_config", None)
        # The attribute may have been re-assigned after construction (e.g. to
        # a plain dict); asdict() only accepts dataclass instances, so convert
        # only in that case and otherwise keep what the base class produced.
        if isinstance(audio_config, GPTAudioConfig):
            output["audio_config"] = asdict(audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary.

        Delegates to :meth:`PretrainedConfig.from_dict` so the standard loading
        keyword arguments are honoured instead of silently dropped. In
        particular, with ``return_unused_kwargs=True`` a
        ``(config, unused_kwargs)`` tuple is returned, which is what the
        transformers loaders unpack. Attribute overrides passed as kwargs are
        also applied.

        Args:
            config_dict: Serialized configuration (e.g. from :meth:`to_dict`).
                The caller's dict is not mutated.
            *args: Accepted for backward compatibility and ignored.
            **kwargs: Forwarded to :meth:`PretrainedConfig.from_dict`.

        Returns:
            The config, or ``(config, unused_kwargs)`` when
            ``return_unused_kwargs=True`` is passed.
        """
        # Pass a copy: the base implementation may write into the dict it gets.
        result = super().from_dict(dict(config_dict), **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        # A kwarg override like ``audio_config={...}`` is applied via setattr
        # by the base class, bypassing __init__; restore the dataclass form.
        if isinstance(getattr(config, "audio_config", None), dict):
            config.audio_config = cls._coerce_audio_config(config.audio_config)
        return result
|
|
|
|
|
|