| |
| |
| |
|
|
| import logging |
| import os |
|
|
| import contextlib |
| from typing import Optional |
|
|
| import numpy as np |
| from unicore.data import ( |
| Dictionary, |
| MaskTokensDataset, |
| NestedDictionaryDataset, |
| NumelDataset, |
| NumSamplesDataset, |
| LMDBDataset, |
| PrependTokenDataset, |
| RightPadDataset, |
| SortDataset, |
| BertTokenizeDataset, |
| data_utils, |
| ) |
| from unicore.tasks import UnicoreTask, register_task |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@register_task("bert")
class BertTask(UnicoreTask):
    """Masked-language-model pretraining task (BERT-style).

    Reads records from an LMDB store, tokenizes them, randomly masks
    tokens, and exposes right-padded, shuffled batches whose targets are
    the original tokens at the masked positions.
    """

    @staticmethod
    def add_args(parser):
        """Register the command-line options this task understands."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--mask-prob",
            type=float,
            default=0.15,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--leave-unmasked-prob",
            type=float,
            default=0.1,
            help="probability that a masked token is unmasked",
        )
        parser.add_argument(
            "--random-token-prob",
            type=float,
            default=0.1,
            help="probability of replacing a token with a random token",
        )

    def __init__(self, args, dictionary):
        """Store the vocabulary and register the mask token.

        Args:
            args: parsed task/command-line arguments (must carry ``seed``).
            dictionary: token vocabulary used for masking and padding.
        """
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        # [MASK] must exist in the vocabulary so masked positions can be
        # encoded; add_symbol returns its index.
        self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Alternate constructor: load the vocabulary from ``args.data``."""
        vocab = Dictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(vocab)))
        return cls(args, vocab)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        lmdb_path = os.path.join(self.args.data, f"{split}.lmdb")
        vocab_path = os.path.join(self.args.data, "dict.txt")

        # Raw LMDB records -> token ids, truncated to args.max_seq_len.
        # NOTE(review): max_seq_len is presumably registered elsewhere
        # (model/global args) — confirm against the surrounding project.
        tokenized = BertTokenizeDataset(
            LMDBDataset(lmdb_path),
            vocab_path,
            max_seq_len=self.args.max_seq_len,
        )

        # Randomly corrupt tokens; yields (masked inputs, prediction targets).
        masked, targets = MaskTokensDataset.apply_mask(
            tokenized,
            self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
        )

        # Seeded permutation => every run/worker produces the same ordering.
        with data_utils.numpy_seed(self.args.seed):
            order = np.random.permutation(len(masked))

        nested = NestedDictionaryDataset(
            {
                "net_input": {
                    "src_tokens": RightPadDataset(
                        masked,
                        pad_idx=self.dictionary.pad(),
                    )
                },
                "target": RightPadDataset(
                    targets,
                    pad_idx=self.dictionary.pad(),
                ),
            },
        )
        self.datasets[split] = SortDataset(nested, sort_order=[order])

    def build_model(self, args):
        """Construct the model this task trains (delegates to unicore.models)."""
        from unicore import models

        return models.build_model(args, self)
|
|