Upload modeling_apex.py with huggingface_hub
Browse files- modeling_apex.py +4 -0
modeling_apex.py
CHANGED
|
@@ -61,6 +61,7 @@ class TaskBranch(nn.Module):
|
|
| 61 |
class APEXModel(PreTrainedModel):
|
| 62 |
config_class = APEXConfig
|
| 63 |
_keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
|
|
|
|
| 64 |
|
| 65 |
def __init__(self, config: APEXConfig):
|
| 66 |
super().__init__(config)
|
|
@@ -110,6 +111,9 @@ class APEXModel(PreTrainedModel):
|
|
| 110 |
self.branch_clarity = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
|
| 111 |
self.branch_naturalness = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
|
| 112 |
|
|
|
|
|
|
|
|
|
|
| 113 |
def forward(self, embedding):
|
| 114 |
shared = self.shared(embedding)
|
| 115 |
return {
|
|
|
|
| 61 |
class APEXModel(PreTrainedModel):
|
| 62 |
config_class = APEXConfig
|
| 63 |
_keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
|
| 64 |
+
_tied_weights_keys = []
|
| 65 |
|
| 66 |
def __init__(self, config: APEXConfig):
|
| 67 |
super().__init__(config)
|
|
|
|
| 111 |
self.branch_clarity = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
|
| 112 |
self.branch_naturalness = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
|
| 113 |
|
| 114 |
+
def _init_weights(self, module):
|
| 115 |
+
pass
|
| 116 |
+
|
| 117 |
def forward(self, embedding):
|
| 118 |
shared = self.shared(embedding)
|
| 119 |
return {
|