import argparse
import csv
import json
import math
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange
from transformers import BertTokenizer, BertModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import GPT2ForSequenceClassification
from datasets import load_dataset
|
|
| from pplm_classification_head import ClassificationHead |
|
|
# Fix random seeds for reproducibility.
torch.manual_seed(0)
np.random.seed(0)
EPSILON = 1e-10  # guards against division by zero when averaging hidden states
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100  # tokenized examples longer than this are skipped
|
|
|
|
| class Discriminator(torch.nn.Module): |
| """Transformer encoder followed by a Classification Head""" |
|
|
| def __init__( |
| self, |
| class_size=None, |
| pretrained_model="gpt2-medium", |
| classifier_head=None, |
| cached_mode=False, |
| device='cpu', |
| fp=None, |
| is_deep=False, |
| is_deeper=False, |
| use_xlnet=False, |
| output_hidden_states=False, |
| unfreeze=False |
| ): |
| super(Discriminator, self).__init__() |
| self.use_xlnet = use_xlnet |
| if pretrained_model.startswith("gpt2") or pretrained_model.startswith("microsoft/DialoGPT"): |
| self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model) |
| self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=output_hidden_states) |
| self.embed_size = self.encoder.transformer.config.hidden_size |
| elif pretrained_model.startswith("bert"): |
| self.tokenizer = BertTokenizer.from_pretrained(pretrained_model) |
| self.encoder = BertModel.from_pretrained(pretrained_model) |
| self.embed_size = self.encoder.config.hidden_size |
        else:
            # Otherwise treat `pretrained_model` as a path to a fine-tuned
            # DialoGPT-large checkpoint and load its weights.
            try:
                self.tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-large")
                self.encoder = GPT2LMHeadModel.from_pretrained(
                    "microsoft/DialoGPT-large",
                    output_hidden_states=output_hidden_states
                )
                self.encoder.load_state_dict(torch.load(pretrained_model, map_location=device))
                self.embed_size = self.encoder.transformer.config.hidden_size
            except Exception as e:
                raise ValueError(
                    "{} model not yet supported".format(pretrained_model)
                ) from e
| if classifier_head: |
| self.classifier_head = classifier_head |
| else: |
| if not class_size: |
| raise ValueError("must specify class_size") |
| self.classifier_head = ClassificationHead( |
| class_size=class_size, |
| embed_size=self.embed_size, |
| is_deep=is_deep, |
| is_deeper=is_deeper, |
| use_xlnet=use_xlnet |
| ) |
        if fp is not None:
            self.classifier_head.load_state_dict(
                torch.load(fp, map_location=device))
| self.cached_mode = cached_mode |
| self.device = device |
| self.unfreeze = unfreeze |
|
|
| def get_classifier(self): |
| return self.classifier_head |
|
|
| def train_custom(self): |
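        # Freeze the encoder unless `unfreeze` was requested; the classification
        # head is always trained.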
| for param in self.encoder.parameters(): |
| param.requires_grad = self.unfreeze |
| self.classifier_head.train() |
|
|
    def avg_representation(self, x):
        # Mask out positions that collate_fn padded with token id 0.
        mask = x.ne(0).unsqueeze(2).repeat(
            1, 1, self.embed_size
        ).float().to(self.device).detach()
        if hasattr(self.encoder, 'transformer'):
            # GPT-2 / DialoGPT: run the base transformer only, skipping the LM head.
            hidden = self.encoder.transformer(x)[0]
        else:
            # BERT: the first output is the sequence of hidden states.
            hidden = self.encoder(x)[0]
        masked_hidden = hidden * mask
        # Mean over the unmasked positions of each sequence.
        avg_hidden = torch.sum(masked_hidden, dim=1) / (
            torch.sum(mask, dim=1).detach() + EPSILON
        )
        return avg_hidden
|
|
    def forward(self, x):
        if self.cached_mode:
            # `x` is already an averaged hidden-state vector.
            avg_hidden = x.to(self.device)
        else:
            avg_hidden = self.avg_representation(x.to(self.device))
        if self.use_xlnet:
            logits = self.classifier_head(None, inputs_embeds=avg_hidden.unsqueeze(dim=2))
        else:
            logits = self.classifier_head(avg_hidden)
        probs = F.log_softmax(logits, dim=-1)
        return probs
|
|
| def predict(self, input_sentence): |
| input_t = self.tokenizer.encode(input_sentence) |
| input_t = torch.tensor([input_t], dtype=torch.long, device=self.device) |
| if self.cached_mode: |
| input_t = self.avg_representation(input_t) |
|
|
| log_probs = self(input_t).data.cpu().numpy().flatten().tolist() |
| prob = [math.exp(log_prob) for log_prob in log_probs] |
| return prob |
|
|
|
|
| class Dataset(data.Dataset): |
    def __init__(self, X, y):
        """Holds pre-tokenized input sequences X and their labels y."""
        self.X = X
        self.y = y
|
|
| def __len__(self): |
| return len(self.X) |
|
|
    def __getitem__(self, index):
        """Returns one example as a dict with keys "X" (sequence) and "y" (label)."""
| data = {} |
| data["X"] = self.X[index] |
| data["y"] = self.y[index] |
| return data |
|
|
|
|
def collate_fn(data):
    def pad_sequences(sequences):
        lengths = [len(seq) for seq in sequences]

        # Zero-pad up to the longest sequence in the batch, capped at 512 tokens.
        padded_sequences = torch.zeros(
            len(sequences),
            min(max(lengths), 512)
        ).long()

        for i, seq in enumerate(sequences):
            end = min(lengths[i], 512)
            # Keep the last `end` tokens so the most recent context is retained.
            padded_sequences[i, :end] = seq[-end:]
        return padded_sequences, lengths
|
|
| item_info = {} |
| for key in data[0].keys(): |
| item_info[key] = [d[key] for d in data] |
|
|
| x_batch, _ = pad_sequences(item_info["X"]) |
| y_batch = torch.tensor(item_info["y"], dtype=torch.long) |
|
|
| return x_batch, y_batch |
|
|
|
|
| def cached_collate_fn(data): |
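    # In cached mode each item's "X" is already an averaged hidden-state vector,
    # so batching is a simple concatenation instead of padding.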
| item_info = {} |
| for key in data[0].keys(): |
| item_info[key] = [d[key] for d in data] |
|
|
| x_batch = torch.cat(item_info["X"], 0) |
| y_batch = torch.tensor(item_info["y"], dtype=torch.long) |
|
|
| return x_batch, y_batch |
|
|
|
|
def train_epoch(data_loader, discriminator, optimizer,
                epoch=0, log_interval=10, device='cpu'):
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)
        samples_so_far += len(input_t)
        # Skip batches containing very long sequences to keep memory usage bounded.
        if input_t.size()[-1] > 225:
            continue
        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = F.nll_loss(output_t, target_t)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far, len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset), loss.item()
                )
            )
        # Drop references so GPU memory can be reclaimed between batches.
        del loss, output_t, input_t, target_t
|
|
|
|
| def evaluate_performance(data_loader, discriminator, device='cpu', check=False, classes=3): |
| discriminator.eval() |
| test_loss = 0 |
| correct_count = 0 |
| hist_len = {} |
| token_len = {} |
| label_len = {} |
| hist_cor = {} |
| token_cor = {} |
| label_cor = {} |
    # Confusion matrix: comp_mat[predicted][target].
    comp_mat = [[0 for i in range(classes)] for j in range(classes)]
| with torch.no_grad(): |
| for batch_idx, (input_t, target_t) in enumerate(data_loader): |
| try: |
| input_t, target_t = input_t.to(device), target_t.to(device) |
| output_t = discriminator(input_t) |
| |
| test_loss += F.nll_loss(output_t, target_t, reduction="sum").item() |
| |
                pred_t = output_t.argmax(dim=1, keepdim=True)
                # view(-1) (instead of squeeze) keeps a 1-D tensor even for batches of size 1.
                res = pred_t.eq(target_t.view_as(pred_t)).view(-1)
                for i, correct in enumerate(res):
| comp_mat[pred_t[i].item()][target_t[i].item()] += 1 |
| if not correct: |
| tmp = input_t[i].tolist() |
| curCount = tmp.count(50256) |
| hist_len[curCount] = hist_len.get(curCount, 0) + 1 |
| token_len[len(tmp)-tmp.count(0)] = token_len.get(len(tmp)-tmp.count(0), 0) + 1 |
| label_len[target_t[i].item()] = label_len.get(target_t[i].item(), 0) + 1 |
| else: |
| correct_count += 1 |
| tmp = input_t[i].tolist() |
| curCount = tmp.count(50256) |
| hist_cor[curCount] = hist_cor.get(curCount, 0) + 1 |
| token_cor[len(tmp)-tmp.count(0)] = token_cor.get(len(tmp)-tmp.count(0), 0) + 1 |
| label_cor[target_t[i].item()] = label_cor.get(target_t[i].item(), 0) + 1 |
                del input_t
                del target_t
            except Exception as e:
                print("Skipping evaluation batch {}: {}".format(batch_idx, e))
                continue
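    # Dump the error/correct breakdowns (keyed by EOS-token count, token count,
    # and label) followed by the confusion matrix.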
| print(hist_len) |
| print(token_len) |
| print(label_len) |
| print(hist_cor) |
| print(token_cor) |
| print(label_cor) |
| print(comp_mat) |
| test_loss /= len(data_loader.dataset) |
| accuracy = correct_count / len(data_loader.dataset) |
|
|
| print( |
| "Performance on test set: " |
| "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format( |
| test_loss, correct_count, len(data_loader.dataset), |
| 100. * accuracy |
| ) |
| ) |
|
|
| return test_loss, accuracy |
|
|
|
|
| def predict(input_sentence, model, classes, cached=False, device='cpu'): |
| input_t = model.tokenizer.encode(input_sentence) |
| input_t = torch.tensor([input_t], dtype=torch.long, device=device) |
| if cached: |
| input_t = model.avg_representation(input_t) |
|
|
| log_probs = model(input_t).data.cpu().numpy().flatten().tolist() |
| print("Input sentence:", input_sentence) |
| print("Predictions:", ", ".join( |
| "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in |
| zip(classes, log_probs) |
| )) |
|
|
|
|
| def get_cached_data_loader(dataset, batch_size, discriminator, |
| shuffle=False, device='cpu'): |
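    # Pre-compute the encoder's averaged hidden representation for every example,
    # so only the classification head has to run during training.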
| data_loader = torch.utils.data.DataLoader(dataset=dataset, |
| batch_size=batch_size, |
| collate_fn=collate_fn) |
|
|
| xs = [] |
| ys = [] |
| for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)): |
| with torch.no_grad(): |
| x = x.to(device) |
| avg_rep = discriminator.avg_representation(x).cpu().detach() |
| avg_rep_list = torch.unbind(avg_rep.unsqueeze(1)) |
| xs += avg_rep_list |
| ys += y.cpu().numpy().tolist() |
|
|
| data_loader = torch.utils.data.DataLoader( |
| dataset=Dataset(xs, ys), |
| batch_size=batch_size, |
| shuffle=shuffle, |
| collate_fn=cached_collate_fn) |
|
|
| return data_loader |
|
|
|
|
| def get_idx2class(dataset_fp): |
| classes = set() |
| with open(dataset_fp) as f: |
| csv_reader = csv.reader(f, delimiter="\t") |
| for row in tqdm(csv_reader, ascii=True): |
| if row: |
| classes.add(row[0]) |
|
|
| return sorted(classes) |
|
|
|
|
| def get_generic_dataset(dataset_fp, tokenizer, device, |
| idx2class=None, add_eos_token=False): |
| if not idx2class: |
| idx2class = get_idx2class(dataset_fp) |
| class2idx = {c: i for i, c in enumerate(idx2class)} |
|
|
| x = [] |
| y = [] |
| with open(dataset_fp) as f: |
| csv_reader = csv.reader(f, delimiter="\t") |
| for i, row in enumerate(tqdm(csv_reader, ascii=True)): |
| if row: |
| label = row[0] |
| text = row[1] |
|
|
                try:
                    seq = tokenizer.encode(text)
                    if len(seq) < max_length_seq:
                        if add_eos_token:
                            # Prepend GPT-2's end-of-text token (id 50256).
                            seq = [50256] + seq
                        seq = torch.tensor(
                            seq,
                            device=device,
                            dtype=torch.long
                        )

                    else:
                        print(
                            "Line {} is longer than maximum length {}".format(
                                i, max_length_seq
                            ))
                        continue

                    x.append(seq)
                    y.append(class2idx[label])

                except Exception:
                    print("Error tokenizing line {}, skipping it".format(i))
|
|
| return Dataset(x, y) |
|
|
|
|
| def train_discriminator( |
| dataset, |
| dataset_fp=None, |
| pretrained_model="gpt2-medium", |
| epochs=10, |
| learning_rate=0.0001, |
| weight_decay=0.0, |
| batch_size=64, |
| log_interval=10, |
| save_model=False, |
| cached=False, |
| no_cuda=False, |
| output_fp='.', |
| fp=None, |
| is_deep=False, |
| is_deeper=False, |
| use_xlnet=False, |
| unfreeze=False |
| ): |
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
    # For GPT-2 models, the end-of-text token (id 50256) is prepended to every example.
    add_eos_token = pretrained_model.startswith("gpt2")
|
|
| if save_model: |
| if not os.path.exists(output_fp): |
| os.makedirs(output_fp) |
| classifier_head_meta_fp = os.path.join( |
| output_fp, "{}_classifier_head_meta.json".format(dataset) |
| ) |
| classifier_head_fp_pattern = os.path.join( |
| output_fp, "{}_classifier_head_epoch".format(dataset) + "_{}.pt" |
| ) |
|
|
| print("Preprocessing {} dataset...".format(dataset)) |
| start = time.time() |
|
|
| if dataset == "SST": |
| idx2class = ["positive", "negative", "very positive", "very negative", |
| "neutral"] |
| class2idx = {c: i for i, c in enumerate(idx2class)} |
|
|
| discriminator = Discriminator( |
| class_size=len(idx2class), |
| pretrained_model=pretrained_model, |
| cached_mode=cached, |
| device=device, |
| fp=fp, |
| is_deep=is_deep, |
| is_deeper=is_deeper, |
| use_xlnet=use_xlnet, |
| unfreeze=unfreeze |
| ).to(device) |
|
|
| text = torchtext_data.Field() |
| label = torchtext_data.Field(sequential=False) |
| train_data, val_data, test_data = datasets.SST.splits( |
| text, |
| label, |
| fine_grained=True, |
| train_subtrees=True, |
| ) |
|
|
| x = [] |
| y = [] |
| for i in trange(len(train_data), ascii=True): |
| seq = TreebankWordDetokenizer().detokenize( |
| vars(train_data[i])["text"] |
| ) |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device=device, dtype=torch.long) |
| x.append(seq) |
| y.append(class2idx[vars(train_data[i])["label"]]) |
| train_dataset = Dataset(x, y) |
|
|
| test_x = [] |
| test_y = [] |
| for i in trange(len(test_data), ascii=True): |
| seq = TreebankWordDetokenizer().detokenize( |
| vars(test_data[i])["text"] |
| ) |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device=device, dtype=torch.long) |
| test_x.append(seq) |
| test_y.append(class2idx[vars(test_data[i])["label"]]) |
| test_dataset = Dataset(test_x, test_y) |
|
|
| discriminator_meta = { |
| "class_size": len(idx2class), |
| "embed_size": discriminator.embed_size, |
| "pretrained_model": pretrained_model, |
| "class_vocab": class2idx, |
| "default_class": 2, |
| } |
|
|
| elif dataset == "5_PerSoothe": |
        if dataset_fp is None:
            raise ValueError("When the 5_PerSoothe dataset is selected, "
                             "dataset_fp needs to be specified as well.")
| idx2class = ["soothes", "improve", "neutral", "trouble", "worsens"] |
| class2idx = {c: i for i, c in enumerate(idx2class)} |
|
|
| discriminator = Discriminator( |
| class_size=len(idx2class), |
| pretrained_model=pretrained_model, |
| cached_mode=cached, |
| device=device, |
| fp=fp, |
| is_deep=is_deep, |
| is_deeper=is_deeper, |
| use_xlnet=use_xlnet, |
| unfreeze=unfreeze |
| ).to(device) |
|
|
| finetuning_data = load_dataset('csv', data_files=dataset_fp) |
| finetuning_data = finetuning_data["train"].train_test_split(test_size=0.1) |
|
|
| train_data = finetuning_data["train"] |
| val_data = finetuning_data["test"] |
| test_data = finetuning_data["test"] |
|
|
| x = [] |
| y = [] |
| for i in trange(len(train_data), ascii=True): |
| seq = train_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device=device, dtype=torch.long) |
| x.append(seq) |
| y.append(class2idx[train_data[i]["label"]]) |
| train_dataset = Dataset(x, y) |
|
|
| test_x = [] |
| test_y = [] |
| for i in trange(len(test_data), ascii=True): |
| seq = test_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device=device, dtype=torch.long) |
| test_x.append(seq) |
| test_y.append(class2idx[test_data[i]["label"]]) |
| test_dataset = Dataset(test_x, test_y) |
|
|
| discriminator_meta = { |
| "class_size": len(idx2class), |
| "embed_size": discriminator.embed_size, |
| "pretrained_model": pretrained_model, |
| "class_vocab": class2idx, |
| "default_class": 2, |
| } |
| |
| elif dataset == "3_PerSoothe": |
        if dataset_fp is None:
            raise ValueError("When the 3_PerSoothe dataset is selected, "
                             "dataset_fp needs to be specified as well.")
| |
| idx2class = ["soothes", "neutral", "worsens"] |
| class2idx = {c: i for i, c in enumerate(idx2class)} |
|
|
| discriminator = Discriminator( |
| class_size=len(idx2class), |
| pretrained_model=pretrained_model, |
| cached_mode=cached, |
| device=device, |
| fp=fp, |
| is_deep=is_deep, |
| is_deeper=is_deeper, |
| use_xlnet=use_xlnet, |
| unfreeze=unfreeze |
| ).to(device) |
| |
| finetuning_data = load_dataset('csv', data_files=dataset_fp) |
| finetuning_data = finetuning_data["train"].train_test_split(test_size=0.1) |
|
|
| train_data = finetuning_data["train"] |
| val_data = finetuning_data["test"] |
| test_data = finetuning_data["test"] |
|
|
| x = [] |
| y = [] |
| for i in trange(len(train_data), ascii=True): |
| seq = train_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device="cpu", dtype=torch.long) |
| x.append(seq) |
| y.append(class2idx[train_data[i]["label"]]) |
| train_dataset = Dataset(x, y) |
|
|
| test_x = [] |
| test_y = [] |
| for i in trange(len(test_data), ascii=True): |
| seq = test_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device="cpu", dtype=torch.long) |
| test_x.append(seq) |
| test_y.append(class2idx[test_data[i]["label"]]) |
| test_dataset = Dataset(test_x, test_y) |
|
|
| discriminator_meta = { |
| "class_size": len(idx2class), |
| "embed_size": discriminator.embed_size, |
| "pretrained_model": pretrained_model, |
| "class_vocab": class2idx, |
| "default_class": 2, |
| } |
| elif dataset == "3_PerSoothe_min": |
        if dataset_fp is None:
            raise ValueError("When the 3_PerSoothe_min dataset is selected, "
                             "dataset_fp needs to be specified as well.")
| |
| idx2class = ["soothes", "neutral", "worsens"] |
| class2idx = {c: i for i, c in enumerate(idx2class)} |
|
|
| discriminator = Discriminator( |
| class_size=len(idx2class), |
| pretrained_model=pretrained_model, |
| cached_mode=cached, |
| device=device, |
| fp=fp, |
| is_deep=is_deep, |
| is_deeper=is_deeper, |
| use_xlnet=use_xlnet, |
| unfreeze=unfreeze |
| ).to(device) |
| |
| finetuning_data = load_dataset('csv', data_files=dataset_fp) |
| finetuning_data = finetuning_data["train"].train_test_split(test_size=0.001) |
|
|
| train_data = finetuning_data["train"] |
| val_data = finetuning_data["test"] |
| test_data = finetuning_data["test"] |
|
|
| x = [] |
| y = [] |
| for i in trange(len(train_data), ascii=True): |
| seq = train_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device="cpu", dtype=torch.long) |
| x.append(seq) |
| y.append(class2idx[train_data[i]["label"]]) |
| train_dataset = Dataset(x, y) |
|
|
| test_x = [] |
| test_y = [] |
| for i in trange(len(test_data), ascii=True): |
| seq = test_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device="cpu", dtype=torch.long) |
| test_x.append(seq) |
| test_y.append(class2idx[test_data[i]["label"]]) |
| test_dataset = Dataset(test_x, test_y) |
|
|
| discriminator_meta = { |
| "class_size": len(idx2class), |
| "embed_size": discriminator.embed_size, |
| "pretrained_model": pretrained_model, |
| "class_vocab": class2idx, |
| "default_class": 2, |
| } |
| elif dataset == "2_PerSoothe": |
        if dataset_fp is None:
            raise ValueError("When the 2_PerSoothe dataset is selected, "
                             "dataset_fp needs to be specified as well.")
| |
| idx2class = ["soothes", "neutral"] |
| class2idx = {c: i for i, c in enumerate(idx2class)} |
|
|
| discriminator = Discriminator( |
| class_size=len(idx2class), |
| pretrained_model=pretrained_model, |
| cached_mode=cached, |
| device=device, |
| fp=fp, |
| is_deep=is_deep, |
| is_deeper=is_deeper, |
| use_xlnet=use_xlnet, |
| unfreeze=unfreeze |
| ).to(device) |
| |
| finetuning_data = load_dataset('csv', data_files=dataset_fp) |
| finetuning_data = finetuning_data["train"].train_test_split(test_size=0.1) |
|
|
| train_data = finetuning_data["train"] |
| val_data = finetuning_data["test"] |
| test_data = finetuning_data["test"] |
|
|
| x = [] |
| y = [] |
| for i in trange(len(train_data), ascii=True): |
| seq = train_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device=device, dtype=torch.long) |
| x.append(seq) |
| y.append(class2idx[train_data[i]["label"]]) |
| train_dataset = Dataset(x, y) |
|
|
| test_x = [] |
| test_y = [] |
| for i in trange(len(test_data), ascii=True): |
| seq = test_data[i]["text"] |
| seq = discriminator.tokenizer.encode(seq) |
| if add_eos_token: |
| seq = [50256] + seq |
| seq = torch.tensor(seq, device=device, dtype=torch.long) |
| test_x.append(seq) |
| test_y.append(class2idx[test_data[i]["label"]]) |
| test_dataset = Dataset(test_x, test_y) |
|
|
| discriminator_meta = { |
| "class_size": len(idx2class), |
| "embed_size": discriminator.embed_size, |
| "pretrained_model": pretrained_model, |
| "class_vocab": class2idx, |
| "default_class": 2, |
| } |
| else: |
| |
| |
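        # Generic dataset: a TSV file where each row has the form "<class>\t<text>".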
|
|
        if dataset_fp is None:
            raise ValueError("When the generic dataset is selected, "
                             "dataset_fp needs to be specified as well.")
|
|
| idx2class = get_idx2class(dataset_fp) |
|
|
| discriminator = Discriminator( |
| class_size=len(idx2class), |
| pretrained_model=pretrained_model, |
| cached_mode=cached, |
| device=device, |
| fp=fp, |
| is_deep=is_deep, |
| is_deeper=is_deeper, |
| use_xlnet=use_xlnet, |
| unfreeze=unfreeze |
| ).to(device) |
|
|
| full_dataset = get_generic_dataset( |
| dataset_fp, discriminator.tokenizer, device, |
| idx2class=idx2class, add_eos_token=add_eos_token |
| ) |
| train_size = int(0.9 * len(full_dataset)) |
| test_size = len(full_dataset) - train_size |
| train_dataset, test_dataset = torch.utils.data.random_split( |
| full_dataset, |
| [train_size, test_size] |
| ) |
|
|
| discriminator_meta = { |
| "class_size": len(idx2class), |
| "embed_size": discriminator.embed_size, |
| "pretrained_model": pretrained_model, |
| "class_vocab": {c: i for i, c in enumerate(idx2class)}, |
| "default_class": 0, |
| } |
|
|
| end = time.time() |
| print("Preprocessed {} data points".format( |
| len(train_dataset) + len(test_dataset)) |
| ) |
| print("Data preprocessing took: {:.3f}s".format(end - start)) |
|
|
| if cached: |
| print("Building representation cache...") |
|
|
| start = time.time() |
|
|
| train_loader = get_cached_data_loader( |
| train_dataset, batch_size, discriminator, |
| shuffle=True, device="cpu" |
| ) |
|
|
| test_loader = get_cached_data_loader( |
| test_dataset, batch_size, discriminator, device="cpu" |
| ) |
|
|
| end = time.time() |
| print("Building representation cache took: {:.3f}s".format(end - start)) |
|
|
| else: |
| train_loader = torch.utils.data.DataLoader(dataset=train_dataset, |
| batch_size=batch_size, |
| shuffle=True, |
| collate_fn=collate_fn) |
| test_loader = torch.utils.data.DataLoader(dataset=test_dataset, |
| batch_size=batch_size, |
| collate_fn=collate_fn) |
|
|
| if save_model: |
| with open(classifier_head_meta_fp, "w") as meta_file: |
| json.dump(discriminator_meta, meta_file) |
|
|
| optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, weight_decay=weight_decay) |
|
|
| test_losses = [] |
| test_accuracies = [] |
|
|
| for epoch in range(epochs): |
|
|
| start = time.time() |
| print("\nEpoch", epoch + 1) |
|
|
| train_epoch( |
| discriminator=discriminator, |
| data_loader=train_loader, |
| optimizer=optimizer, |
| epoch=epoch, |
| log_interval=log_interval, |
| device=device |
| ) |
        test_loss, test_accuracy = evaluate_performance(
            data_loader=test_loader,
            discriminator=discriminator,
            device=device,
            classes=len(idx2class)
        )
|
|
| end = time.time() |
| print("Epoch took: {:.3f}s".format(end - start)) |
|
|
| test_losses.append(test_loss) |
| test_accuracies.append(test_accuracy) |
|
|
| print("\nExample prediction") |
| predict(example_sentence, discriminator, idx2class, |
| cached=cached, device=device) |
        if save_model:
            # Save the classifier head after every epoch so the best-performing
            # epoch can be picked from the summary printed below.
            torch.save(discriminator.get_classifier().state_dict(),
                       classifier_head_fp_pattern.format(epoch + 1))
    if save_model and unfreeze:
        # If the encoder was fine-tuned as well, save its weights once after
        # training, reusing the epoch-0 slot of the same filename pattern.
        torch.save(discriminator.encoder.state_dict(),
                   classifier_head_fp_pattern.format(0))
| min_loss = float("inf") |
| min_loss_epoch = 0 |
| max_acc = 0.0 |
| max_acc_epoch = 0 |
| print("Test performance per epoch") |
| print("epoch\tloss\tacc") |
| for e, (loss, acc) in enumerate(zip(test_losses, test_accuracies)): |
| print("{}\t{}\t{}".format(e + 1, loss, acc)) |
| if loss < min_loss: |
| min_loss = loss |
| min_loss_epoch = e + 1 |
| if acc > max_acc: |
| max_acc = acc |
| max_acc_epoch = e + 1 |
| print("Min loss: {} - Epoch: {}".format(min_loss, min_loss_epoch)) |
| print("Max acc: {} - Epoch: {}".format(max_acc, max_acc_epoch)) |
|
|
| return discriminator, discriminator_meta |
|
|
|
|
def load_classifier_head(weights_path, meta_path, device='cpu', is_deep=False, is_deeper=False):
| with open(meta_path, 'r', encoding="utf8") as f: |
| meta_params = json.load(f) |
| classifier_head = ClassificationHead( |
| class_size=meta_params['class_size'], |
| embed_size=meta_params['embed_size'], |
| is_deep=is_deep, |
| is_deeper=is_deeper |
| ).to(device) |
| classifier_head.load_state_dict( |
| torch.load(weights_path, map_location=device)) |
| classifier_head.eval() |
| return classifier_head, meta_params |
|
|
|
|
def load_discriminator(weights_path, meta_path, device='cpu', is_deep=False, is_deeper=False):
| classifier_head, meta_param = load_classifier_head( |
| weights_path, meta_path, device, is_deep, is_deeper |
| ) |
| discriminator = Discriminator( |
| pretrained_model=meta_param['pretrained_model'], |
| classifier_head=classifier_head, |
| cached_mode=False, |
| device=device |
| ) |
| return discriminator, meta_param |
|
|
|
|
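# Example of reusing a trained head elsewhere (hypothetical paths, for illustration only):
#   discriminator, meta = load_discriminator(
#       "models/3_PerSoothe_classifier_head_epoch_5.pt",
#       "models/3_PerSoothe_classifier_head_meta.json",
#       device="cpu")
#   probs = discriminator.predict("I understand, that sounds really difficult.")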
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser( |
| description="Train a discriminator on top of GPT-2 representations") |
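    # Example invocation (hypothetical script name and file paths, shown for illustration):
    #   python pplm_discrim_train.py --dataset 3_PerSoothe \
    #       --dataset_fp data/persoothe.csv --pretrained_model microsoft/DialoGPT-large \
    #       --epochs 5 --save_model --output_fp models/
    # For the PerSoothe datasets the CSV is expected to have "text" and "label" columns.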
    parser.add_argument("--dataset", type=str, default="SST",
                        choices=("SST", "generic", "5_PerSoothe", "3_PerSoothe", "3_PerSoothe_min", "2_PerSoothe"),
                        help="dataset to train the discriminator on. "
                             "In case of generic, the dataset is expected "
                             "to be a TSV file with structure: class \\t text")
    parser.add_argument("--dataset_fp", type=str, default="",
                        help="File path of the dataset to use. "
                             "Needed for the generic and PerSoothe datasets")
| parser.add_argument("--pretrained_model", type=str, default="gpt2-medium", |
| help="Pretrained model to use as encoder") |
| parser.add_argument("--epochs", type=int, default=10, metavar="N", |
| help="Number of training epochs") |
    parser.add_argument("--learning_rate", type=float, default=0.0001,
                        help="Learning rate")
| parser.add_argument("--weight_decay", type=float, default=0.0, |
| help="Weight decay") |
| parser.add_argument("--batch_size", type=int, default=64, metavar="N", |
| help="input batch size for training (default: 64)") |
| parser.add_argument("--log_interval", type=int, default=10, metavar="N", |
| help="how many batches to wait before logging training status") |
| parser.add_argument("--save_model", action="store_true", |
| help="whether to save the model") |
| parser.add_argument("--cached", action="store_true", |
| help="whether to cache the input representations") |
| parser.add_argument("--no_cuda", action="store_true", |
| help="use to turn off cuda") |
| parser.add_argument("--output_fp", default=".", |
| help="path to save the output to") |
| parser.add_argument("--fp", type=str, default=None, help="pretrained discriminator") |
| parser.add_argument("--is_deep", action="store_true", |
| help="whether to use deep classifier") |
| parser.add_argument("--is_deeper", action="store_true", |
| help="whether to use deeper classifier") |
| parser.add_argument("--use_xlnet", action="store_true", |
| help="whether to use xlnet classifier") |
| parser.add_argument("--unfreeze", action="store_true", |
| help="whether to train encoder as well") |
| args = parser.parse_args() |
|
|
| train_discriminator(**(vars(args))) |
|
|