| |
| |
| import gzip |
| import json |
| import math |
| import random |
| import shelve |
| import torch |
|
|
| import subprocess as sp |
|
|
| from math import ceil |
| from torch.utils.data import DataLoader, Sampler, Dataset |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| from env import END_OF_TEXT_TOKEN |
| from gpt2_training.train_utils import (InputFeatures, InputFeatures_train, |
| RedditExample) |
|
|
|
|
class BucketSampler(Sampler):
    """Sampler that yields mini-batches of indices grouped by length.

    Indices are partitioned into buckets of `bucket_size`; within each
    bucket they are sorted by sequence length (longest first) so that
    examples in a batch have similar lengths and padding is minimized.
    """

    def __init__(self, lens, bucket_size, batch_size,
                 droplast=False, shuffle=True):
        # lens: per-example sequence lengths, indexed by example id
        self._lens = lens
        self._batch_size = batch_size
        self._bucket_size = bucket_size
        self._droplast = droplast
        self._shuf = shuffle

    def __iter__(self):
        ids = list(range(len(self._lens)))
        if self._shuf:
            random.shuffle(ids)
        batches = []
        for start in range(0, len(ids), self._bucket_size):
            # sort each bucket by length, longest sequences first
            bucket = sorted(ids[start:start + self._bucket_size],
                            key=lambda idx: self._lens[idx], reverse=True)
            for off in range(0, len(bucket), self._batch_size):
                batch = bucket[off:off + self._batch_size]
                if self._droplast and len(batch) < self._batch_size:
                    continue
                batches.append(batch)
        if self._shuf:
            # shuffle batch order so training does not see a
            # monotone length curriculum
            random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        full, rem = divmod(len(self._lens), self._bucket_size)
        bucket_sizes = [self._bucket_size] * full + [rem]
        if self._droplast:
            return sum(size // self._batch_size for size in bucket_sizes)
        return sum(math.ceil(size / self._batch_size)
                   for size in bucket_sizes)
|
|
|
|
class GPT2FeatureDataset(Dataset):
    """PyTorch dataset over pre-featurized GPT-2 training examples.

    Each element of `features` is a dict whose keys match the
    InputFeatures_train constructor (input_ids, position_ids,
    token_type_ids, lm_labels, input_len, ...).
    """

    def __init__(self, features, max_len=None):
        # features: list of feature dicts
        # max_len: optional cap; longer inputs are truncated from the left
        self.features = features
        self.max_len = max_len

    def __getitem__(self, i):
        # Work on a shallow copy so the truncation/key-deletion below
        # does not mutate the stored feature dict in place (fix: the
        # original mutated self.features[i] directly).
        feat_dict = dict(self.features[i])
        if self.max_len is not None and feat_dict['input_len'] > self.max_len:
            # Truncate on the left so the most recent tokens
            # (the response end) are kept.
            for key in ('input_ids', 'position_ids',
                        'token_type_ids', 'lm_labels'):
                feat_dict[key] = feat_dict[key][-self.max_len:]
        # Some db files carry context_len/response_len, which
        # InputFeatures_train does not accept — drop them.
        # (fix: removed leftover pdb.set_trace() debugging hook and a
        # misleading "db file missing" print that fired when the key
        # was actually present)
        for key in ('context_len', 'response_len'):
            feat_dict.pop(key, None)
        return InputFeatures_train(**feat_dict)

    def __len__(self):
        return len(self.features)

    @staticmethod
    def collate(features):
        """Pad a list of feature objects into right-padded batch tensors.

        Returns (input_ids, position_ids, token_type_ids, labels);
        labels are padded with -1, which the LM loss ignores.
        """
        def _pad(attr, pad_value):
            return pad_sequence(
                [torch.tensor(getattr(f, attr), dtype=torch.long)
                 for f in features],
                batch_first=True, padding_value=pad_value)

        input_ids = _pad('input_ids', 0)
        position_ids = _pad('position_ids', 0)
        token_type_ids = _pad('token_type_ids', 0)
        labels = _pad('lm_labels', -1)
        return (input_ids, position_ids, token_type_ids, labels)
|
|
|
|
class BucketingDataLoader(object):
    """Loads gzip-compressed JSON chunks from a shelve db and serves
    length-bucketed mini-batches of GPT-2 features.
    """

    def __init__(self, db_name, batch_size, max_seq_len,
                 bucket=100, shuffle=True):
        # Open read-only; each db value is a gzip-compressed JSON list
        # of feature dicts.
        self.db = shelve.open(f'{db_name}/db', 'r')
        self.batch_size = batch_size
        self.max_len = max_seq_len
        # A bucket spans `bucket` batches worth of examples.
        self.bucket_size = bucket * batch_size
        self.shuffle = shuffle

    def _get_keys(self):
        # Hook point: subclasses may shard the key list (see the
        # distributed variant).
        return list(self.db.keys())

    def __iter__(self):
        keys = self._get_keys()
        if self.shuffle:
            random.shuffle(keys)
        for key in keys:
            chunk = json.loads(gzip.decompress(self.db[key]).decode('utf-8'))
            # Discard examples longer than max_len and record the
            # lengths of the survivors for bucketing.
            trunc_chunk = []
            lens = []
            for feat in chunk:
                if feat['input_len'] > self.max_len:
                    continue
                trunc_chunk.append(feat)
                lens.append(feat['input_len'])

            dataset = GPT2FeatureDataset(trunc_chunk, self.max_len)
            sampler = BucketSampler(lens, self.bucket_size, self.batch_size,
                                    droplast=True, shuffle=self.shuffle)
            loader = DataLoader(dataset, batch_sampler=sampler,
                                num_workers=0,  # load in the main process
                                collate_fn=GPT2FeatureDataset.collate)
            yield from loader

    def __len__(self):
        raise NotImplementedError()

    def __del__(self):
        # fix: if shelve.open raised in __init__, self.db was never
        # bound and the original __del__ raised a secondary
        # AttributeError; guard with getattr instead.
        db = getattr(self, 'db', None)
        if db is not None:
            db.close()
|
|
|
class DistributedBucketingDataLoader(BucketingDataLoader):
    """Rank-sharded variant: each replica reads every num_replica-th
    db key, so replicas consume disjoint chunks.
    """

    def __init__(self, rank, num_replica, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank = rank
        self.num_replica = num_replica

    def _get_keys(self):
        all_keys = list(self.db.keys())
        # stride over the key list starting at this replica's rank
        return all_keys[self.rank::self.num_replica]
|
|
|
|
def convert_examples_to_features_dynamic(examples, tokenizer,
                                         max_seq_length=512):
    """Featurize RedditExamples into unpadded InputFeatures.

    Layout per example: context + <eot> + response + <eot>, with
    lm_labels masked (-1) over the context. Examples whose context
    alone exceeds the budget are dropped.
    """
    def featurize(example):
        end_of_text_id = tokenizer.encoder[END_OF_TEXT_TOKEN]
        context_id = tokenizer.encode(example.context)
        response_id = tokenizer.encode(example.response)

        # +2 accounts for the two end-of-text separators
        total_len = len(context_id) + len(response_id) + 2
        if total_len > max_seq_length:
            overflow = total_len - max_seq_length
            if len(context_id) > overflow:
                # drop the oldest context tokens from the left
                context_id = context_id[overflow:]
            else:
                keep = max_seq_length - len(context_id) - 2
                if keep < 0:
                    # context alone does not fit — skip this example
                    return None
                # truncate the response on the right
                response_id = response_id[:keep]

        input_ids = (context_id + [end_of_text_id]
                     + response_id + [end_of_text_id])
        # labels are shifted left by one relative to input_ids:
        # context positions are masked with -1, then the response and
        # its closing <eot> are predicted, final slot unused
        lm_labels = ([-1] * len(context_id)
                     + response_id + [end_of_text_id] + [-1])
        position_ids = list(range(len(input_ids)))
        token_type_ids = [0] * len(input_ids)
        return InputFeatures(example.conv_id, input_ids, position_ids,
                             token_type_ids, lm_labels,
                             len(context_id), len(response_id))

    feats = []
    for ex in examples:
        feat = featurize(ex)
        if feat is not None:
            feats.append(feat)
    return feats
|
|
|
|
class DynamicBatchingLoader(object):
    """Reads a raw TSV corpus (src \\t tgt1 \\t tgt2 ...) and yields
    featurized mini-batches; used to validate perplexity.
    """

    def __init__(self, corpus_file, tokenizer, normalize_data,
                 batch_size, max_seq_length):
        self.corpus = corpus_file
        self.toker = tokenizer
        self.norm = normalize_data  # collapse internal whitespace if True
        self.bs = batch_size        # batch size counted in corpus *lines*
        self.max_seq_length = max_seq_length
        self.num_examples = self.get_len(corpus_file)

    def __iter__(self, epoch=1):
        """Iterate `epoch` passes over the corpus; epoch <= 0 loops forever."""
        if epoch > 0:
            for _ in range(epoch):
                yield from self._iter_epoch()
        else:
            while True:
                yield from self._iter_epoch()

    def __len__(self):
        # batches per epoch, including the trailing partial batch
        return ceil(self.num_examples / self.bs)

    def _iter_epoch(self):
        """Yield featurized batches for one pass over the corpus file."""
        example_id = 0
        examples = []
        lines_in_batch = 0
        with open(self.corpus, 'r', encoding="utf-8") as corpus:
            for line in corpus:
                # one source line may carry several target responses
                src, *tgt_all = line.split('\t')
                for tgt in tgt_all:
                    if self.norm:
                        src_line = ' '.join(src.strip().split())
                        tgt_line = ' '.join(tgt.strip().split())
                    else:
                        src_line = src.strip()
                        tgt_line = tgt.strip()
                    examples.append(
                        RedditExample(example_id, src_line, tgt_line))
                    example_id += 1
                lines_in_batch += 1
                if lines_in_batch >= self.bs:
                    yield from self._emit(examples)
                    examples = []
                    lines_in_batch = 0
        # fix: the original discarded the trailing partial batch (the
        # StopIteration from next() threw away accumulated examples),
        # making __iter__ inconsistent with __len__
        if examples:
            yield from self._emit(examples)

    def _emit(self, examples):
        """Featurize `examples` and yield one tensor batch; yields
        nothing if every example was filtered out by max_seq_length
        (fix: pad_sequence([]) would otherwise raise)."""
        features = convert_examples_to_features_dynamic(
            examples, self.toker, self.max_seq_length)
        if features:
            yield self._batch_feature(features)

    def _batch_feature(self, features):
        """Pad a list of InputFeatures into a tuple of batch tensors."""
        input_ids = pad_sequence(
            [torch.tensor(f.choices_features['input_ids'], dtype=torch.long)
             for f in features],
            batch_first=True, padding_value=0)
        position_ids = pad_sequence(
            [torch.tensor(f.choices_features['position_ids'],
                          dtype=torch.long)
             for f in features],
            batch_first=True, padding_value=0)
        token_type_ids = pad_sequence(
            [torch.tensor(f.choices_features['token_type_ids'],
                          dtype=torch.long)
             for f in features],
            batch_first=True, padding_value=0)
        # -1 marks positions ignored by the LM loss
        labels = pad_sequence(
            [torch.tensor(f.lm_labels, dtype=torch.long) for f in features],
            batch_first=True, padding_value=-1)
        context_len = torch.tensor([f.context_len for f in features],
                                   dtype=torch.long)
        response_len = torch.tensor([f.response_len for f in features],
                                    dtype=torch.long)
        return (input_ids, position_ids, token_type_ids, labels,
                context_len, response_len)

    def get_len(self, corpus):
        """Count corpus lines in pure Python.

        fix: the original shelled out to `wc -l`, which is non-portable
        and undercounts a final line lacking a trailing newline even
        though _iter_epoch does process that line.
        """
        with open(corpus, 'r', encoding="utf-8") as f:
            return sum(1 for _ in f)
|
|