File size: 900 Bytes
from transformers import PretrainedConfig
class APEXConfig(PretrainedConfig):
    """Configuration container for models registered under the ``"apex"`` type.

    Every constructor argument is stored verbatim as an attribute of the same
    name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        mert_model_name: Name of the MERT model to use (presumably a Hugging
            Face Hub identifier -- confirm against the model code).
        layer_indices: List of layer indices. Defaults to ``[2, 5, 8, -1]``
            when ``None``. (Presumably selects MERT hidden layers -- TODO
            confirm against the model code.)
        segment_sec: Segment length; the name suggests seconds.
        seed: Random seed value.
        input_dim: Input feature dimension.
        shared_dims: List of layer sizes for the shared part of the network.
            Defaults to ``[512, 256]`` when ``None``.
        branch_dims: List of layer sizes for each branch. Defaults to
            ``[128, 64]`` when ``None``.
        dropout_shared: Dropout value associated with the shared layers.
        dropout_branch: Dropout value associated with the branch layers.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "apex"

    def __init__(
        self,
        mert_model_name="m-a-p/MERT-v1-95M",
        layer_indices=None,
        segment_sec=30,
        seed=42,
        input_dim=768,
        shared_dims=None,
        branch_dims=None,
        dropout_shared=0.3,
        dropout_branch=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)

        # List defaults are built per call rather than in the signature:
        # a list literal used as a default value is created once and would be
        # shared (and mutable) across every config instance that relies on it.
        if layer_indices is None:
            layer_indices = [2, 5, 8, -1]
        if shared_dims is None:
            shared_dims = [512, 256]
        if branch_dims is None:
            branch_dims = [128, 64]

        self.mert_model_name = mert_model_name
        self.layer_indices = layer_indices
        self.segment_sec = segment_sec
        self.seed = seed
        self.input_dim = input_dim
        self.shared_dims = shared_dims
        self.branch_dims = branch_dims
        self.dropout_shared = dropout_shared
        self.dropout_branch = dropout_branch