from pdb import set_trace  # NOTE(review): unused in the visible code; likely a leftover debugging import
from typing import List, Optional

import torch
from transformers import PretrainedConfig
|
|
class MultiLabelClassifierConfig(PretrainedConfig):
    """Configuration for a multi-label classifier built on a transformer backbone.

    Stores the classifier's hyper-parameters next to the standard
    ``PretrainedConfig`` fields. The ``labels`` list determines
    ``num_classes`` and, unless explicitly supplied, the ``id2label`` /
    ``label2id`` mappings handed to ``PretrainedConfig``.

    Args:
        embedding_dim: Embedding size stored on the config (default 768).
        labels: Ordered label names; position ``i`` becomes class id ``i``.
            ``None`` (the default) means an empty label list. The sequence is
            copied, so later changes to the caller's list do not leak in.
        transformer_name: Identifier of the transformer backbone; only stored
            here, not loaded.
        hidden_dim: Hidden size for the classifier head. Presumably a
            recurrent layer, given ``num_layers`` / ``bidirectional``; confirm
            against the model that consumes this config.
        num_layers: Number of layers for that head.
        bidirectional: Whether that head is bidirectional.
        dropout: Dropout probability stored on the config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "multi_label_classification"
    # PretrainedConfig.__init__ sets ``self.problem_type`` from kwargs
    # (default None), which would shadow this class attribute, so __init__
    # forwards it explicitly.
    problem_type = "multi_label_classification"

    def __init__(
        self,
        embedding_dim: int = 768,
        labels: Optional[List[str]] = None,
        transformer_name: str = "bert-base-uncased",
        hidden_dim: int = 256,
        num_layers: int = 2,
        bidirectional: bool = True,
        dropout: float = 0.3,
        **kwargs,
    ):
        # Avoid a shared mutable default and avoid aliasing the caller's list,
        # keeping ``labels`` and ``num_classes`` consistent.
        labels = [] if labels is None else list(labels)

        self.transformer_name = transformer_name
        self.hidden_dim = hidden_dim
        self.labels = labels
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.num_classes = len(labels)
        self.embedding_dim = embedding_dim

        # Derive the label mappings from ``labels`` unless they were passed in
        # (e.g. when re-creating the config from a saved dict).
        kwargs.setdefault("id2label", dict(enumerate(labels)))
        kwargs.setdefault("label2id", {label: idx for idx, label in enumerate(labels)})

        # Without this, PretrainedConfig.__init__ would reset problem_type to
        # None. A missing or None value (e.g. from an older saved config)
        # falls back to the class-level setting.
        if kwargs.get("problem_type") is None:
            kwargs["problem_type"] = type(self).problem_type

        super().__init__(**kwargs)
| |