hellosindh committed on
Commit
60e393a
·
verified ·
1 Parent(s): 106df91

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -3
inference.py CHANGED
@@ -80,13 +80,13 @@ def load_tokenizer(data_dir):
80
  def load_bert_mlm(model_dir):
81
  from transformers import BertForMaskedLM
82
  return BertForMaskedLM.from_pretrained(
83
- str(model_dir / "mlm" / "best")).to(device).eval()
84
 
85
 
86
  def load_bert_cls(model_dir):
87
  from transformers import BertForSequenceClassification
88
  return BertForSequenceClassification.from_pretrained(
89
- str(model_dir / "cls" / "best")).to(device).eval()
90
 
91
 
92
  def load_ngram(model_dir):
@@ -112,7 +112,7 @@ def load_electra(model_dir):
112
  attention_mask=attention_mask)
113
  return self.classifier(self.dropout(out.last_hidden_state))
114
 
115
- p = model_dir / "electra" / "best"
116
  with open(p / "discriminator_config.json") as f:
117
  cfg = json.load(f)
118
  m = ElectraDisc(BertConfig(**cfg))
 
80
  def load_bert_mlm(model_dir):
81
  from transformers import BertForMaskedLM
82
  return BertForMaskedLM.from_pretrained(
83
+ str(model_dir / "mlm")).to(device).eval()
84
 
85
 
86
def load_bert_cls(model_dir):
    """Load the fine-tuned BERT sequence classifier from ``model_dir / "cls"``.

    Args:
        model_dir: ``pathlib.Path`` to the model root directory.

    Returns:
        A ``BertForSequenceClassification`` moved to the module-level
        ``device`` and switched to eval mode for inference.
    """
    # Import locally so transformers is only required when this loader runs.
    from transformers import BertForSequenceClassification

    checkpoint = str(model_dir / "cls")
    model = BertForSequenceClassification.from_pretrained(checkpoint)
    return model.to(device).eval()
90
 
91
 
92
  def load_ngram(model_dir):
 
112
  attention_mask=attention_mask)
113
  return self.classifier(self.dropout(out.last_hidden_state))
114
 
115
+ p = model_dir / "electra"
116
  with open(p / "discriminator_config.json") as f:
117
  cfg = json.load(f)
118
  m = ElectraDisc(BertConfig(**cfg))