| import argparse |
| import json |
| import torch |
| import tqdm |
| from transformers import BertTokenizer, BertForSequenceClassification |
| from src.models.DualEncoderModule import DualEncoderModule |
|
|
|
|
def triple_to_string(x):
    """Join the whitespace-stripped items of *x* with the " </s> " separator token."""
    stripped_parts = (part.strip() for part in x)
    return " </s> ".join(stripped_parts)
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser( |
| description="Rerank the QA paris and keep top 3 QA paris as evidence using a pre-trained BERT model." |
| ) |
| parser.add_argument( |
| "-i", |
| "--top_k_qa_file", |
| default="data_store/dev_top_k_qa.json", |
| help="Json file with claim and top k generated question-answer pairs.", |
| ) |
| parser.add_argument( |
| "-o", |
| "--output_file", |
| default="data_store/dev_top_3_rerank_qa.json", |
| help="Json file with the top3 reranked questions.", |
| ) |
| parser.add_argument( |
| "-ckpt", |
| "--best_checkpoint", |
| type=str, |
| default="pretrained_models/bert_dual_encoder.ckpt", |
| ) |
| parser.add_argument( |
| "--top_n", |
| type=int, |
| default=3, |
| help="top_n question answer pairs as evidence to keep.", |
| ) |
| args = parser.parse_args() |
|
|
| examples = [] |
| with open(args.top_k_qa_file) as f: |
| for line in f: |
| examples.append(json.loads(line)) |
|
|
| bert_model_name = "bert-base-uncased" |
|
|
| tokenizer = BertTokenizer.from_pretrained(bert_model_name) |
| bert_model = BertForSequenceClassification.from_pretrained( |
| bert_model_name, num_labels=2, problem_type="single_label_classification" |
| ) |
| device = "cuda:0" if torch.cuda.is_available() else "cpu" |
| trained_model = DualEncoderModule.load_from_checkpoint( |
| args.best_checkpoint, tokenizer=tokenizer, model=bert_model |
| ).to(device) |
|
|
| with open(args.output_file, "w", encoding="utf-8") as output_file: |
| for example in tqdm.tqdm(examples): |
| strs_to_score = [] |
| values = [] |
|
|
| bm25_qau = example["bm25_qau"] if "bm25_qau" in example else [] |
| claim = example["claim"] |
|
|
| for question, answer, url in bm25_qau: |
| str_to_score = triple_to_string([claim, question, answer]) |
|
|
| strs_to_score.append(str_to_score) |
| values.append([question, answer, url]) |
|
|
| if len(bm25_qau) > 0: |
| encoded_dict = tokenizer( |
| strs_to_score, |
| max_length=512, |
| padding="longest", |
| truncation=True, |
| return_tensors="pt", |
| ).to(device) |
|
|
| input_ids = encoded_dict["input_ids"] |
| attention_masks = encoded_dict["attention_mask"] |
|
|
| scores = torch.softmax( |
| trained_model(input_ids, attention_mask=attention_masks).logits, |
| axis=-1, |
| )[:, 1] |
|
|
| top_n = torch.argsort(scores, descending=True)[: args.top_n] |
| evidence = [ |
| { |
| "question": values[i][0], |
| "answer": values[i][1], |
| "url": values[i][2], |
| } |
| for i in top_n |
| ] |
| else: |
| evidence = [] |
|
|
| json_data = { |
| "claim_id": example["claim_id"], |
| "claim": claim, |
| "evidence": evidence, |
| } |
| output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n") |
| output_file.flush() |
|
|