from transformers import PretrainedConfig


class APEXConfig(PretrainedConfig):
    """Configuration for the ``apex`` model.

    Stores the hyper-parameters passed to ``__init__`` as attributes so they
    can be serialized and reloaded through ``PretrainedConfig``.

    Args:
        mert_model_name: Name of the MERT backbone to use; presumably a
            Hugging Face Hub repo id (default ``"m-a-p/MERT-v1-95M"``).
        layer_indices: Indices of backbone layers to use. Negative values
            presumably count from the last layer. Defaults to
            ``[2, 5, 8, -1]``.
        segment_sec: Segment length; presumably in seconds (default ``30``).
        seed: Random seed value (default ``42``).
        input_dim: Input feature dimension (default ``768``).
        shared_dims: Hidden sizes of the shared layers. Defaults to
            ``[512, 256]``.
        branch_dims: Hidden sizes of the branch layers. Defaults to
            ``[128, 64]``.
        dropout_shared: Dropout rate for the shared layers (default ``0.3``).
        dropout_branch: Dropout rate for the branch layers (default ``0.1``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "apex"

    def __init__(
        self,
        mert_model_name="m-a-p/MERT-v1-95M",
        layer_indices=None,
        segment_sec=30,
        seed=42,
        input_dim=768,
        shared_dims=None,
        branch_dims=None,
        dropout_shared=0.3,
        dropout_branch=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # List-valued defaults are created fresh on every call. A list literal
        # in the signature would be one object shared by all instances, so
        # mutating it on one config would leak into every later default config.
        if layer_indices is None:
            layer_indices = [2, 5, 8, -1]
        if shared_dims is None:
            shared_dims = [512, 256]
        if branch_dims is None:
            branch_dims = [128, 64]

        self.mert_model_name = mert_model_name
        self.layer_indices = layer_indices
        self.segment_sec = segment_sec
        self.seed = seed
        self.input_dim = input_dim
        self.shared_dims = shared_dims
        self.branch_dims = branch_dims
        self.dropout_shared = dropout_shared
        self.dropout_branch = dropout_branch