Update inference.py
Browse files- inference.py +3 -3
inference.py
CHANGED
|
@@ -80,13 +80,13 @@ def load_tokenizer(data_dir):
|
|
| 80 |
def load_bert_mlm(model_dir):
|
| 81 |
from transformers import BertForMaskedLM
|
| 82 |
return BertForMaskedLM.from_pretrained(
|
| 83 |
-
str(model_dir / "mlm"
|
| 84 |
|
| 85 |
|
| 86 |
def load_bert_cls(model_dir):
|
| 87 |
from transformers import BertForSequenceClassification
|
| 88 |
return BertForSequenceClassification.from_pretrained(
|
| 89 |
-
str(model_dir / "cls"
|
| 90 |
|
| 91 |
|
| 92 |
def load_ngram(model_dir):
|
|
@@ -112,7 +112,7 @@ def load_electra(model_dir):
|
|
| 112 |
attention_mask=attention_mask)
|
| 113 |
return self.classifier(self.dropout(out.last_hidden_state))
|
| 114 |
|
| 115 |
-
p = model_dir / "electra"
|
| 116 |
with open(p / "discriminator_config.json") as f:
|
| 117 |
cfg = json.load(f)
|
| 118 |
m = ElectraDisc(BertConfig(**cfg))
|
|
|
|
| 80 |
def load_bert_mlm(model_dir):
|
| 81 |
from transformers import BertForMaskedLM
|
| 82 |
return BertForMaskedLM.from_pretrained(
|
| 83 |
+
str(model_dir / "mlm")).to(device).eval()
|
| 84 |
|
| 85 |
|
| 86 |
def load_bert_cls(model_dir):
|
| 87 |
from transformers import BertForSequenceClassification
|
| 88 |
return BertForSequenceClassification.from_pretrained(
|
| 89 |
+
str(model_dir / "cls")).to(device).eval()
|
| 90 |
|
| 91 |
|
| 92 |
def load_ngram(model_dir):
|
|
|
|
| 112 |
attention_mask=attention_mask)
|
| 113 |
return self.classifier(self.dropout(out.last_hidden_state))
|
| 114 |
|
| 115 |
+
p = model_dir / "electra"
|
| 116 |
with open(p / "discriminator_config.json") as f:
|
| 117 |
cfg = json.load(f)
|
| 118 |
m = ElectraDisc(BertConfig(**cfg))
|