from transformers import PretrainedConfig


class MusicFMConfig(PretrainedConfig):
    """Configuration for a MusicFM model.

    Stores tokenizer/codebook, feature-extraction, encoder, and masking
    hyperparameters consumed by the MusicFM model implementation.

    Args:
        num_codebooks: Number of random-projection codebooks.
        codebook_dim: Dimensionality of each codebook entry.
        codebook_size: Number of entries per codebook.
        features: Names of input features to compute (defaults to
            ``["melspec_2048"]``).
        hop_length: Hop length (in samples) of the feature extractor.
        n_mels: Number of mel filterbank bins.
        conv_dim: Channel width of the convolutional front-end.
        encoder_dim: Hidden size of the transformer encoder.
        encoder_depth: Number of transformer encoder layers.
        mask_hop: Masking hop size in seconds.
        mask_prob: Probability that a span is masked during pre-training.
        is_flash: Whether to use flash attention.
        stat: Per-feature normalization statistics (defaults to ``{}``).
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "musicfm"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: list[str] | None = None,
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        stat: dict[str, float] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.num_codebooks = num_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size
        # None-sentinel instead of mutable defaults: a literal [] / {} default
        # would be shared (and mutable) across every instance of this class.
        self.features = features if features is not None else ["melspec_2048"]
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.conv_dim = conv_dim
        self.encoder_dim = encoder_dim
        self.encoder_depth = encoder_depth
        self.mask_hop = mask_hop
        self.mask_prob = mask_prob
        self.is_flash = is_flash
        self.stat = stat if stat is not None else {}


class MusicFMInferenceConfig(MusicFMConfig):
    """MusicFM configuration specialized for inference.

    Identical to :class:`MusicFMConfig` except that it additionally records
    which encoder layer's hidden states to extract.

    Args:
        layer_index: Index of the encoder layer whose output is used at
            inference time (default 9). All other arguments are forwarded to
            :class:`MusicFMConfig`.
    """

    model_type = "musicfm_inference"

    def __init__(
        self,
        num_codebooks: int = 1,
        codebook_dim: int = 16,
        codebook_size: int = 4096,
        features: list[str] | None = None,
        hop_length: int = 240,
        n_mels: int = 128,
        conv_dim: int = 512,
        encoder_dim: int = 1024,
        encoder_depth: int = 12,
        mask_hop: float = 0.4,
        mask_prob: float = 0.6,
        is_flash: bool = False,
        layer_index: int = 9,
        stat: dict[str, float] | None = None,
        **kwargs,
    ) -> None:
        # Mutable-default fix: pass None through; the parent resolves it to
        # the same effective defaults (["melspec_2048"] / {}) per instance.
        super().__init__(
            num_codebooks=num_codebooks,
            codebook_dim=codebook_dim,
            codebook_size=codebook_size,
            features=features,
            hop_length=hop_length,
            n_mels=n_mels,
            conv_dim=conv_dim,
            encoder_dim=encoder_dim,
            encoder_depth=encoder_depth,
            mask_hop=mask_hop,
            mask_prob=mask_prob,
            is_flash=is_flash,
            stat=stat,
            **kwargs,
        )
        self.layer_index = layer_index