Update inference.py
Browse files- inference.py +3 -2
inference.py
CHANGED
|
@@ -299,7 +299,8 @@ def task_validate(seq_str, models):
|
|
| 299 |
|
| 300 |
def task_predict(seq_str, models):
|
| 301 |
tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
|
| 302 |
-
|
|
|
|
| 303 |
seq = parse_sequence(seq_str)
|
| 304 |
preds = bert_predict_mask(seq, tok, mlm, top_k=5)
|
| 305 |
print(f"\n Input: {seq_str}")
|
|
@@ -406,4 +407,4 @@ def main():
|
|
| 406 |
|
| 407 |
|
| 408 |
if __name__ == "__main__":
|
| 409 |
-
main()
|
|
|
|
| 299 |
|
| 300 |
def task_predict(seq_str, models):
|
| 301 |
tok, cls, ngram, elec_tok, elec_disc, glyph_map = models
|
| 302 |
+
model_dir, data_dir = get_model_dir()
|
| 303 |
+
mlm = load_bert_mlm(model_dir)
|
| 304 |
seq = parse_sequence(seq_str)
|
| 305 |
preds = bert_predict_mask(seq, tok, mlm, top_k=5)
|
| 306 |
print(f"\n Input: {seq_str}")
|
|
|
|
| 407 |
|
| 408 |
|
| 409 |
if __name__ == "__main__":
|
| 410 |
+
main()
|