Update configuration_ced.py
Browse files- configuration_ced.py +9 -5
configuration_ced.py
CHANGED
|
@@ -14,7 +14,7 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
""" CED model configuration"""
|
| 16 |
|
| 17 |
-
|
| 18 |
from transformers import PretrainedConfig
|
| 19 |
from transformers.utils import logging
|
| 20 |
from transformers.utils.hub import cached_file
|
|
@@ -127,10 +127,14 @@ class CedConfig(PretrainedConfig):
|
|
| 127 |
|
| 128 |
if self.outputdim == 527:
|
| 129 |
with open(cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r") as f:
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
self.label2id = {v: k for k, v in self.id2label.items()}
|
| 135 |
else:
|
| 136 |
self.id2label = None
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
""" CED model configuration"""
|
| 16 |
|
| 17 |
+
import csv
|
| 18 |
from transformers import PretrainedConfig
|
| 19 |
from transformers.utils import logging
|
| 20 |
from transformers.utils.hub import cached_file
|
|
|
|
| 127 |
|
| 128 |
if self.outputdim == 527:
|
| 129 |
with open(cached_file("topel/ConvNeXt-Tiny-AT", "class_labels_indices.csv"), "r") as f:
|
| 130 |
+
reader = csv.reader(f)
|
| 131 |
+
next(reader) # skip header
|
| 132 |
+
self.id2label = {}
|
| 133 |
+
for row in reader:
|
| 134 |
+
idx = int(row[0])
|
| 135 |
+
label = row[2]
|
| 136 |
+
if label not in self.id2label.values():
|
| 137 |
+
self.id2label[idx] = label
|
| 138 |
self.label2id = {v: k for k, v in self.id2label.items()}
|
| 139 |
else:
|
| 140 |
self.id2label = None
|