File size: 900 Bytes
7f39b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import PretrainedConfig

class APEXConfig(PretrainedConfig):
    """Configuration container for the ``"apex"`` model type.

    Every constructor argument is stored unchanged as an attribute of the
    same name; any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "apex"

    def __init__(
        self,
        mert_model_name="m-a-p/MERT-v1-95M",
        layer_indices=None,
        segment_sec=30,
        seed=42,
        input_dim=768,
        shared_dims=None,
        branch_dims=None,
        dropout_shared=0.3,
        dropout_branch=0.1,
        **kwargs
    ):
        """Build the configuration.

        Args:
            mert_model_name: Stored as ``mert_model_name``. Presumably the
                identifier of the MERT backbone to load -- confirm against
                the model code.
            layer_indices: Stored as ``layer_indices``. Defaults to
                ``[2, 5, 8, -1]`` when ``None``. Presumably the backbone
                layers whose outputs are used -- confirm against the model
                code.
            segment_sec: Stored as ``segment_sec``. Presumably a segment
                length in seconds -- confirm against the model code.
            seed: Stored as ``seed``.
            input_dim: Stored as ``input_dim``.
            shared_dims: Stored as ``shared_dims``. Defaults to
                ``[512, 256]`` when ``None``.
            branch_dims: Stored as ``branch_dims``. Defaults to
                ``[128, 64]`` when ``None``.
            dropout_shared: Stored as ``dropout_shared``.
            dropout_branch: Stored as ``dropout_branch``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        # The list-valued defaults are created per instance. A list literal in
        # the signature is evaluated once and shared, so an in-place edit on
        # one config would leak into every config created afterwards.
        if layer_indices is None:
            layer_indices = [2, 5, 8, -1]
        if shared_dims is None:
            shared_dims = [512, 256]
        if branch_dims is None:
            branch_dims = [128, 64]

        self.mert_model_name = mert_model_name
        self.layer_indices = layer_indices
        self.segment_sec = segment_sec
        self.seed = seed
        self.input_dim = input_dim
        self.shared_dims = shared_dims
        self.branch_dims = branch_dims
        self.dropout_shared = dropout_shared
        self.dropout_branch = dropout_branch