| import argparse |
| import time |
| import json |
| import nltk |
| from rank_bm25 import BM25Okapi |
| import numpy as np |
| import torch |
| from transformers import BloomTokenizerFast, BloomForCausalLM |
|
|
|
|
| def claim2prompts(example): |
| claim = example["claim"] |
|
|
| |
| claim_str = "Evidence: " |
|
|
| for question in example["questions"]: |
| q_text = question["question"].strip() |
| if len(q_text) == 0: |
| continue |
|
|
| if not q_text[-1] == "?": |
| q_text += "?" |
|
|
| answer_strings = [] |
|
|
| for a in question["answers"]: |
| if a["answer_type"] in ["Extractive", "Abstractive"]: |
| answer_strings.append(a["answer"]) |
| if a["answer_type"] == "Boolean": |
| answer_strings.append( |
| a["answer"] |
| + ", because " |
| + a["boolean_explanation"].lower().strip() |
| ) |
|
|
| for a_text in answer_strings: |
| if not a_text[-1] in [".", "!", ":", "?"]: |
| a_text += "." |
|
|
| |
| prompt_lookup_str = a_text |
| this_q_claim_str = ( |
| claim_str + " " + a_text.strip() + "||Question answered: " + q_text |
| ) |
| yield ( |
| prompt_lookup_str, |
| this_q_claim_str.replace("\n", " ").replace("||", "\n"), |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser( |
| description="Use a prompt to generate questions that could be answered by top-k retrieved evidence. Output generated questions." |
| ) |
| parser.add_argument("--reference_corpus", default="data/train.json", help="") |
| parser.add_argument("--target_file", default="data/dev.json", help="") |
| parser.add_argument( |
| "-i", |
| "--top_k_target_knowledge", |
| default="data_store/dev_top_k_sentences.json", |
| help="Directory where the sentences for the scraped data is saved.", |
| ) |
| parser.add_argument( |
| "-o", |
| "--output_questions", |
| default="data_store/dev_top_k_qa.json", |
| help="Directory where the sentences for the scraped data is saved.", |
| ) |
| parser.add_argument( |
| "--top_k", |
| default=100, |
| type=int, |
| help="How many documents should we pick out with BM25", |
| ) |
| args = parser.parse_args() |
|
|
| |
| with open(args.reference_corpus, "r", encoding="utf-8") as json_file: |
| train_examples = json.load(json_file) |
|
|
| prompt_corpus, tokenized_corpus = [], [] |
|
|
| for example in train_examples: |
| for lookup_str, prompt in claim2prompts(example): |
| entry = nltk.word_tokenize(lookup_str) |
| tokenized_corpus.append(entry) |
| prompt_corpus.append(prompt) |
|
|
| prompt_bm25 = BM25Okapi(tokenized_corpus) |
|
|
| |
| tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1") |
| model = BloomForCausalLM.from_pretrained( |
| "bigscience/bloom-7b1", |
| device_map="auto", |
| torch_dtype=torch.bfloat16, |
| offload_folder="./offload", |
| ) |
|
|
| with open(args.output_questions, "w", encoding="utf-8") as output_file: |
| with open(args.top_k_target_knowledge, "r", encoding="utf-8") as json_file: |
| for i, line in enumerate(json_file): |
| data = json.loads(line) |
| top_k_sentences_urls = data[f"top_{args.top_k}"] |
| claim = data["claim"] |
| claim_id = data["claim_id"] |
|
|
| bm25_qau = [] |
| |
| for sent_i, sentences_urls in enumerate(top_k_sentences_urls): |
|
|
| prompt_lookup_str = sentences_urls["sentence"] |
| url = sentences_urls["url"] |
|
|
| prompt_s = prompt_bm25.get_scores( |
| nltk.word_tokenize(prompt_lookup_str) |
| ) |
| prompt_n = 10 |
| prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n] |
| prompt_docs = [prompt_corpus[i] for i in prompt_top_n] |
|
|
| claim_prompt = ( |
| "Evidence: " |
| + prompt_lookup_str.replace("\n", " ") |
| + "\nQuestion answered: " |
| ) |
|
|
| prompt = "\n\n".join(prompt_docs + [claim_prompt]) |
|
|
| inputs = tokenizer([prompt], padding=True, return_tensors="pt").to( |
| model.device |
| ) |
| st = time.time() |
| outputs = model.generate( |
| inputs["input_ids"], |
| max_length=5000, |
| num_beams=2, |
| no_repeat_ngram_size=2, |
| early_stopping=True, |
| ) |
| print( |
| f"Generated QA for sent {sent_i} in file {i}. Time elapsed: {time.time() - st}" |
| ) |
|
|
| tgt_text = tokenizer.batch_decode( |
| outputs[:, inputs["input_ids"].shape[-1] :], |
| skip_special_tokens=True, |
| )[0] |
|
|
| |
| tgt_text = tgt_text[:250] |
|
|
| qau_pair = [ |
| tgt_text.strip().split("?")[0].replace("\n", " ") + "?", |
| prompt_lookup_str.replace("\n", " "), |
| url, |
| ] |
|
|
| bm25_qau.append(qau_pair) |
|
|
| json_data = { |
| "claim_id": claim_id, |
| "claim": claim, |
| "bm25_qau": bm25_qau, |
| } |
| output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n") |
| output_file.flush() |
|
|