| from transformers import PretrainedConfig, BertConfig |
| from typing import List |
|
|
class VGCNConfig(BertConfig):
    """Configuration for a VGCN model built on top of a BERT-style encoder.

    Extends ``BertConfig`` with settings for the vocabulary-graph (GCN)
    component. Each option below is range-checked where a check exists, then
    stored as an attribute of the same name. All remaining keyword arguments
    are forwarded unchanged to ``BertConfig.__init__``.

    Args:
        bert_model: Identifier of the underlying BERT checkpoint, stored as-is
            (default ``'readerbench/RoBERT-base'``). NOTE(review): presumably
            a Hugging Face Hub id or local path; confirm against the model code.
        gcn_adj_matrix: Reference to the GCN adjacency matrix, stored as a
            string (default ``''``). NOTE(review): likely a file path; the
            expected format is not visible here.
        max_seq_len: Maximum sequence length. Must satisfy
            ``1 <= max_seq_len <= 512``.
        npmi_threshold: Threshold in ``[0.0, 1.0]``. The name suggests an NPMI
            cut-off for vocabulary-graph edges (TODO confirm with the graph
            builder).
        tf_threshold: Threshold in ``[0.0, 1.0]``. The name suggests a
            term-frequency cut-off (TODO confirm).
        vocab_type: One of ``'all'``, ``'pmi'`` or ``'tf'``.
        gcn_embedding_dim: Size of the GCN embedding (default ``32``). This
            value is not validated here.
        **kwargs: Forwarded to ``BertConfig.__init__``.

    Raises:
        ValueError: If ``vocab_type`` is not one of the allowed values, or if
            ``max_seq_len``, ``npmi_threshold`` or ``tf_threshold`` falls
            outside its range. NaN values are treated as out of range.
    """

    model_type = "vgcn"

    def __init__(
        self,
        bert_model: str = 'readerbench/RoBERT-base',
        gcn_adj_matrix: str = '',
        max_seq_len: int = 256,
        npmi_threshold: float = 0.2,
        tf_threshold: float = 0.0,
        vocab_type: str = "all",
        gcn_embedding_dim: int = 32,
        **kwargs,
    ):
        if vocab_type not in ("all", "pmi", "tf"):
            raise ValueError(f"`vocab_type` must be 'all', 'pmi' or 'tf', got {vocab_type}.")
        # Chained comparisons evaluate to False for NaN, so the negated form
        # rejects NaN. The previous `x < lo or x > hi` form let NaN slip
        # through the checks.
        if not 1 <= max_seq_len <= 512:
            raise ValueError(f"`max_seq_len` must be between 1 and 512, got {max_seq_len}.")
        if not 0.0 <= npmi_threshold <= 1.0:
            raise ValueError(f"`npmi_threshold` must be between 0.0 and 1.0, got {npmi_threshold}.")
        if not 0.0 <= tf_threshold <= 1.0:
            raise ValueError(f"`tf_threshold` must be between 0.0 and 1.0, got {tf_threshold}.")

        self.gcn_adj_matrix = gcn_adj_matrix
        self.max_seq_len = max_seq_len
        self.npmi_threshold = npmi_threshold
        self.tf_threshold = tf_threshold
        self.vocab_type = vocab_type
        self.gcn_embedding_dim = gcn_embedding_dim
        self.bert_model = bert_model

        # Everything not consumed above (standard BERT options such as
        # hidden_size) is handled by BertConfig.
        super().__init__(**kwargs)