hellosindh commited on
Commit
92c23ad
·
verified ·
1 Parent(s): 3872f06

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -2
inference.py CHANGED
@@ -299,7 +299,8 @@ def task_validate(seq_str, models):
299
 
300
  def task_predict(seq_str, models):
301
  tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
302
- mlm = load_bert_mlm(models[0].__class__) # reload MLM
 
303
  seq = parse_sequence(seq_str)
304
  preds = bert_predict_mask(seq, tok, mlm, top_k=5)
305
  print(f"\n Input: {seq_str}")
@@ -406,4 +407,4 @@ def main():
406
 
407
 
408
  if __name__ == "__main__":
409
- main()
 
299
 
300
  def task_predict(seq_str, models):
301
  tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
302
+ model_dir, data_dir = get_model_dir()
303
+ mlm = load_bert_mlm(model_dir)
304
  seq = parse_sequence(seq_str)
305
  preds = bert_predict_mask(seq, tok, mlm, top_k=5)
306
  print(f"\n Input: {seq_str}")
 
407
 
408
 
409
  if __name__ == "__main__":
410
+ main()