import os

# Disable TensorFlow oneDNN optimizations; must be set before any TF import
# reads the environment so the flag actually takes effect.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = '0'



# NOTE(review): `login` is imported but never called in this file --
# presumably kept for interactive Hugging Face Hub authentication; confirm
# before removing.
from huggingface_hub import login





from typing import Union, Any, Dict





import argparse

import datasets

from transformers.utils import logging, check_min_version

from transformers.utils.versions import require_version



# Project-local reader implementation and the unified example schema.
from retro_reader import RetroReader

from retro_reader.constants import EXAMPLE_FEATURES

import torch




# Fail fast if the installed transformers/datasets versions are too old.
check_min_version("4.13.0.dev0")



require_version("datasets>=1.8.0")


# Module-level logger following the transformers logging convention.
logger = logging.get_logger(__name__)
|
|
|
|
|
def schema_integrate(example) -> Union[Dict, Any]:
    """Convert a batched SQuAD v2.0 example into the library's unified schema.

    Answerable rows keep their original answers dict; unanswerable rows
    (empty "text" list) receive the placeholder
    ``{"text": [""], "answer_start": [-1]}`` and ``is_impossible=True``.

    Args:
        example: batched columnar dict with SQuAD v2.0 keys
            ("id", "title", "question", "context", "answers").

    Returns:
        A columnar dict matching ``EXAMPLE_FEATURES``.
    """
    n_rows = len(example["title"])

    answers, is_impossible = [], []
    for raw_answer in example["answers"]:
        if raw_answer["text"]:
            answers.append(raw_answer)
            is_impossible.append(False)
        else:
            # Placeholder for no-answer samples: -1 marks "no span".
            answers.append({"text": [""], "answer_start": [-1]})
            is_impossible.append(True)

    return {
        "guid": example["id"],
        "question": example["question"],
        "context": example["context"],
        "answers": answers,
        "title": example["title"],
        "classtype": [""] * n_rows,
        "source": ["squad_v2"] * n_rows,
        "is_impossible": is_impossible,
        "dataset": ["squad_v2"] * n_rows,
    }
|
|
|
|
|
|
|
def data_aug_for_multiple_answers(examples) -> Union[Dict, Any]:
    """Expand batched examples so each gold answer becomes its own row.

    Answerable rows with several answers are duplicated once per answer,
    each copy keeping a single text/answer_start pair. Unanswerable rows
    are normalized to empty answer lists. Single-answer rows pass through
    unchanged.

    Args:
        examples: batched columnar dict (parallel lists per key) containing
            at least "answers" and "is_impossible" columns.

    Returns:
        A new columnar dict with the same keys and possibly more rows.

    Raises:
        ValueError: if an answerable row's "text" and "answer_start" lists
            have mismatched lengths.
    """
    result = {key: [] for key in examples.keys()}

    def update(i, answers=None):
        # Copy row i into the result, optionally overriding its answers.
        for key in result.keys():
            if key == "answers" and answers is not None:
                result[key].append(answers)
            else:
                result[key].append(examples[key][i])

    for i, (answers, unanswerable) in enumerate(
        zip(examples["answers"], examples["is_impossible"])
    ):
        answerable = not unanswerable
        # Validate alignment explicitly -- the original `assert` would be
        # stripped under `python -O`. Preserve the original short-circuit:
        # answer_start[0] is only inspected when the lengths disagree, and
        # a leading -1 marks the no-answer placeholder, which is exempt.
        if (
            len(answers["text"]) != len(answers["answer_start"])
            and answers["answer_start"][0] != -1
        ):
            raise ValueError(
                "Mismatched 'text' and 'answer_start' lengths in row "
                f"{i}: {answers}"
            )
        if answerable and len(answers["text"]) > 1:
            # One duplicated row per gold answer.
            for n_ans in range(len(answers["text"])):
                ans = {
                    "text": [answers["text"][n_ans]],
                    "answer_start": [answers["answer_start"][n_ans]],
                }
                update(i, ans)
        elif not answerable:
            # Normalize unanswerable rows to empty answer lists.
            update(i, {"text": [], "answer_start": []})
        else:
            update(i)

    return result
|
|
|
|
|
def main(args):
    """Prepare SQuAD v2.0 and train the Retro Reader.

    Pipeline: download SQuAD v2.0, optionally truncate for a debug smoke
    run, map to the library's unified schema, expand multi-answer rows,
    then build the reader and train the requested module(s).
    """
    print("Loading SQuAD v2.0 dataset ...")
    data = datasets.load_dataset("squad_v2")

    # Debug mode: keep only five examples per split for a quick smoke run.
    if args.debug:
        for split in ("train", "validation"):
            data[split] = data[split].select(range(5))

    print("Integrating into the schema used in this library ...")
    data = data.map(
        schema_integrate,
        batched=True,
        remove_columns=data.column_names["train"],
        features=EXAMPLE_FEATURES,
    )

    # Report how many samples are unanswerable in each split.
    num_unanswerable_train = sum(data["train"]["is_impossible"])
    num_unanswerable_valid = sum(data["validation"]["is_impossible"])
    logger.warning(f"Number of unanswerable sample for SQuAD v2.0 train dataset: {num_unanswerable_train}")
    logger.warning(f"Number of unanswerable sample for SQuAD v2.0 validation dataset: {num_unanswerable_valid}")

    print("Data augmentation for multiple answers ...")
    augmented_train = data["train"].map(
        data_aug_for_multiple_answers,
        batched=True,
        batch_size=args.batch_size,
        num_proc=5,
    )
    data = datasets.DatasetDict(
        {"train": augmented_train, "validation": data["validation"]}
    )

    print("Loading Retro Reader ...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    retro_reader = RetroReader.load(
        train_examples=data["train"],
        eval_examples=data["validation"],
        config_file=args.configs,
        device=device,
    )
    # Optionally resume from a previously saved checkpoint.
    if args.resume_checkpoint:
        retro_reader = retro_reader.load_checkpoint(args.resume_checkpoint)

    print("Training ...")
    retro_reader.train(module=args.module)
    logger.warning("Train retrospective reader Done.")
|
|
|
|
|
if __name__ == "__main__":

    # Command-line interface for the training script.
    parser = argparse.ArgumentParser()

    parser.add_argument("--configs", "-c", type=str, default="configs/train_distilbert.yaml", help="config file path")

    # NOTE: batch size here is the `datasets.map` batch size for the
    # answer-augmentation pass, not the training batch size.
    parser.add_argument("--batch_size", "-b", type=int, default=1024, help="batch size")

    parser.add_argument("--resume_checkpoint", "-r", type=str, default=None, help="resume checkpoint path")

    parser.add_argument("--module", "-m", type=str, default="all", choices=["all", "sketch", "intensive"], help="module to train")

    parser.add_argument("--debug", "-d", action="store_true", help="debug mode")

    args = parser.parse_args()

    main(args)