from typing import Optional, Sequence

from transformers import PretrainedConfig
class APEXConfig(PretrainedConfig):
    """Configuration container for the ``"apex"`` model type.

    Every constructor argument is stored as an attribute of the same name.
    List-valued arguments are copied into fresh lists. Any extra keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        mert_model_name: Model identifier string. Presumably the Hugging Face
            Hub id of the MERT backbone; TODO confirm against the model code.
        layer_indices: Integer indices, presumably selecting backbone hidden
            layers, with negative values counting from the end. Confirm with
            the model code. Defaults to ``[2, 5, 8, -1]``.
        segment_sec: Presumably the audio segment length in seconds.
        seed: Random seed value.
        input_dim: Presumably the feature width fed into the head layers.
        shared_dims: Layer widths, presumably of a shared trunk. Defaults to
            ``[512, 256]``.
        branch_dims: Layer widths, presumably of each branch. Defaults to
            ``[128, 64]``.
        dropout_shared: Dropout value, presumably applied in the shared trunk.
        dropout_branch: Dropout value, presumably applied in the branches.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "apex"

    def __init__(
        self,
        mert_model_name: str = "m-a-p/MERT-v1-95M",
        layer_indices: Optional[Sequence[int]] = None,
        segment_sec: float = 30,
        seed: int = 42,
        input_dim: int = 768,
        shared_dims: Optional[Sequence[int]] = None,
        branch_dims: Optional[Sequence[int]] = None,
        dropout_shared: float = 0.3,
        dropout_branch: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Use None sentinels rather than list literals in the signature.
        # A list default is created once, when the function is defined, and
        # is shared by every call. Mutating one config's list would therefore
        # silently change the defaults of all later configs. Caller-supplied
        # sequences are copied for the same reason: the config must not alias
        # the caller's list.
        self.mert_model_name = mert_model_name
        self.layer_indices = (
            [2, 5, 8, -1] if layer_indices is None else list(layer_indices)
        )
        self.segment_sec = segment_sec
        self.seed = seed
        self.input_dim = input_dim
        self.shared_dims = [512, 256] if shared_dims is None else list(shared_dims)
        self.branch_dims = [128, 64] if branch_dims is None else list(branch_dims)
        self.dropout_shared = dropout_shared
        self.dropout_branch = dropout_branch