Jaavid25 commited on
Commit
aa9dfd4
·
verified ·
1 Parent(s): 958eb1a

Upload modeling_apex.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_apex.py +4 -0
modeling_apex.py CHANGED
@@ -61,6 +61,7 @@ class TaskBranch(nn.Module):
61
  class APEXModel(PreTrainedModel):
62
  config_class = APEXConfig
63
  _keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
 
64
 
65
  def __init__(self, config: APEXConfig):
66
  super().__init__(config)
@@ -110,6 +111,9 @@ class APEXModel(PreTrainedModel):
110
  self.branch_clarity = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
111
  self.branch_naturalness = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
112
 
 
 
 
113
  def forward(self, embedding):
114
  shared = self.shared(embedding)
115
  return {
 
61
  class APEXModel(PreTrainedModel):
62
  config_class = APEXConfig
63
  _keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
64
+ _tied_weights_keys = []
65
 
66
  def __init__(self, config: APEXConfig):
67
  super().__init__(config)
 
111
  self.branch_clarity = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
112
  self.branch_naturalness = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
113
 
114
+ def _init_weights(self, module):
115
+ pass
116
+
117
  def forward(self, embedding):
118
  shared = self.shared(embedding)
119
  return {