Update modeling_apex.py
Browse files- modeling_apex.py +1 -4
modeling_apex.py
CHANGED
|
@@ -8,9 +8,8 @@ from transformers import PreTrainedModel, AutoProcessor, AutoModel
|
|
| 8 |
from .configuration_apex import APEXConfig
|
| 9 |
|
| 10 |
|
| 11 |
-
|
| 12 |
# BUILDING BLOCKS
|
| 13 |
-
# -------------------------------
|
| 14 |
class SharedBlock(nn.Module):
|
| 15 |
def __init__(self, in_dim, out_dim, dropout):
|
| 16 |
super().__init__()
|
|
@@ -55,9 +54,7 @@ class TaskBranch(nn.Module):
|
|
| 55 |
return torch.sigmoid(self.branch(x)) * self.scale + self.shift
|
| 56 |
|
| 57 |
|
| 58 |
-
# -------------------------------
|
| 59 |
# APEX MODEL
|
| 60 |
-
# -------------------------------
|
| 61 |
class APEXModel(PreTrainedModel):
|
| 62 |
config_class = APEXConfig
|
| 63 |
_keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
|
|
|
|
| 8 |
from .configuration_apex import APEXConfig
|
| 9 |
|
| 10 |
|
| 11 |
+
|
| 12 |
# BUILDING BLOCKS
|
|
|
|
| 13 |
class SharedBlock(nn.Module):
|
| 14 |
def __init__(self, in_dim, out_dim, dropout):
|
| 15 |
super().__init__()
|
|
|
|
| 54 |
return torch.sigmoid(self.branch(x)) * self.scale + self.shift
|
| 55 |
|
| 56 |
|
|
|
|
| 57 |
# APEX MODEL
|
|
|
|
| 58 |
class APEXModel(PreTrainedModel):
|
| 59 |
config_class = APEXConfig
|
| 60 |
_keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
|