| |
| |
|
|
| |
|
|
| import os |
| import sys |
| import argparse |
| from tqdm import trange |
| from torchtext import data as torchtext_data |
| from torchtext import datasets |
|
|
| import torch |
| import torch.utils.data as data |
|
|
| from torchtext.vocab import Vectors, GloVe, CharNGram, FastText |
| from nltk.tokenize.treebank import TreebankWordDetokenizer |
| import torch |
| import torch.optim |
| import torch.nn.functional as F |
| import numpy as np |
| from IPython import embed |
| from operator import add |
| from run_gpt2 import top_k_logits |
| from style_utils import to_var |
| import copy |
| import pickle |
| from torch.utils.data import DataLoader |
| from torch.utils.data.dataset import random_split |
| import torch.optim as optim |
|
|
# Seed both random number generators so repeated runs start from the same state.
np.random.seed(0)
torch.manual_seed(0)


# Put the directory two levels above this file on the import path (at index 1)
# before importing pytorch_pretrained_bert.
_script_dir = os.path.abspath(os.path.dirname(__file__))
lab_root = os.path.join(_script_dir, '..', '..')
sys.path.insert(1, lab_root)


from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
from torch.autograd import Variable


# The tokenizer and the language model are loaded once, at import time, from
# the same local directory; the rest of the module uses them as globals.
_gpt2_dir = 'gpt-2_pt_models/345M/'

tokenizer = GPT2Tokenizer.from_pretrained(_gpt2_dir)

model = GPT2LMHeadModel.from_pretrained(_gpt2_dir)
|
|
|
|
class ClassificationHead(torch.nn.Module):
    """Single linear layer that maps a hidden vector to per-class scores.

    Args:
        class_size: number of output classes.
        embed_size: size of the last dimension of the incoming hidden state.
    """

    def __init__(self, class_size=5, embed_size=2048):
        super(ClassificationHead, self).__init__()
        # Keep the sizes around so callers can inspect how the head was built.
        self.embed_size = embed_size
        self.class_size = class_size
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        """Return the unnormalised class scores for ``hidden_state``."""
        return self.mlp(hidden_state)
|
|
|
|
class Discriminator(torch.nn.Module):
    """Classifier that appends a special embedding to the input sequence.

    The input ids are embedded by the shared module-level GPT-2 ``model``, a
    random "special token" embedding is concatenated along dimension 1, the
    result is run through the transformer, and the output at the last
    position is classified by a ``ClassificationHead``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.classifierhead = ClassificationHead()
        self.model = model
        # NOTE(review): the special token is a plain tensor, not a registered
        # parameter/buffer, so it is absent from ``parameters()`` and
        # ``state_dict()``. After ``repeat``/``cuda`` it is also no longer a
        # leaf tensor. Confirm whether it was meant to be trainable.
        self.spltoken = Variable(torch.randn(1, 1, 1024).type(torch.FloatTensor), requires_grad=True)
        # Repeated 10 times along dim 0: ``forward`` concatenates along dim 1,
        # so inputs are presumably expected to have batch size 10 -- confirm.
        self.spltoken = self.spltoken.repeat(10, 1, 1)
        self.spltoken = self.spltoken.cuda()

    def train(self, mode=True):
        """Freeze the language model, then set the training mode.

        The signature mirrors ``torch.nn.Module.train`` so that ``eval()``,
        which calls ``self.train(False)``, keeps working.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            This module.
        """
        for param in self.model.parameters():
            param.requires_grad = False
        return super(Discriminator, self).train(mode)

    def forward(self, x):
        """Return log-probabilities over the classes for the token ids ``x``."""
        x = self.model.forward_embed(x)
        x = torch.cat((x, self.spltoken), dim=1)
        _, x = self.model.forward_transformer_embed(x, add_one=True)
        # Classify from the last position of the last element of the second
        # value returned by the transformer call.
        x = self.classifierhead(x[-1][:, -1, :])
        x = F.log_softmax(x, dim=-1)
        return x
|
|
|
|
class Discriminator2(torch.nn.Module):
    """Classifier over the sum of the transformer's hidden states.

    Args:
        class_size: number of output classes.
        embed_size: size of the hidden vectors fed to the classification head.
    """

    def __init__(self, class_size=5, embed_size=1024):
        super(Discriminator2, self).__init__()
        self.classifierhead = ClassificationHead(class_size=class_size, embed_size=embed_size)
        self.model = model
        self.embed_size = embed_size

    def get_classifier(self):
        """Return the classification head sub-module."""
        return self.classifierhead

    def train_custom(self):
        """Freeze the language model and switch only the head to train mode."""
        for weight in self.model.parameters():
            weight.requires_grad = False
        self.classifierhead.train()

    def forward(self, x):
        """Return log-probabilities over the classes for the token ids ``x``."""
        embedded = model.forward_embed(x)
        hidden, _ = model.forward_transformer_embed(embedded)
        # Collapse dimension 1 of the hidden states by summation.
        pooled = torch.sum(hidden, dim=1)
        scores = self.classifierhead(pooled)
        return F.log_softmax(scores, dim=-1)
|
|
class Discriminator2mean(torch.nn.Module):
    """Classifier over the masked mean of the transformer's hidden states.

    Positions whose input id is 0 are excluded from the mean.

    Args:
        class_size: number of output classes.
        embed_size: size of the hidden vectors fed to the classification head.
    """

    def __init__(self, class_size=5, embed_size=1024):
        super(Discriminator2mean, self).__init__()
        self.classifierhead = ClassificationHead(class_size=class_size, embed_size=embed_size)
        self.model = model
        self.embed_size = embed_size

    def get_classifier(self):
        """Return the classification head sub-module."""
        return self.classifierhead

    def train_custom(self):
        """Freeze the language model and switch only the head to train mode."""
        for param in self.model.parameters():
            param.requires_grad = False
        self.classifierhead.train()

    def forward(self, x):
        """Return log-probabilities over the classes for the token ids ``x``.

        Args:
            x: integer tensor of token ids; id 0 marks positions to ignore.
                NOTE(review): id 0 may also be a genuine vocabulary entry, in
                which case that token is dropped from the mean too -- confirm.
        """
        # 1.0 where the id is non-zero, 0.0 elsewhere. Built on x's own device
        # (no CPU round trip, no hard-coded GPU); the trailing axis lets the
        # mask broadcast over the hidden dimension.
        mask = x.ne(0).unsqueeze(-1).float().detach()

        embedded = self.model.forward_embed(x)
        hidden, _ = self.model.forward_transformer_embed(embedded)

        # Masked mean over dimension 1; the epsilon avoids a division by zero
        # for rows that contain no non-zero id at all.
        summed = torch.sum(hidden * mask, dim=1)
        counts = torch.sum(mask, dim=1)
        x = summed / (counts + 1e-10)

        x = self.classifierhead(x)
        x = F.log_softmax(x, dim=-1)
        return x
|
|
class Dataset(data.Dataset):
    """In-memory dataset pairing the inputs ``X`` with the labels ``y``."""

    def __init__(self, X, y):
        """Store the two parallel, indexable collections as given."""
        self.X = X
        self.y = y

    def __len__(self):
        """Number of items, taken from ``X``."""
        return len(self.X)

    def __getitem__(self, index):
        """Return ``{'X': ..., 'y': ...}`` for the item at ``index``."""
        return {'X': self.X[index], 'y': self.y[index]}
|
|
|
|
def collate_fn(data):
    """Collate ``{'X': sequence, 'y': label}`` items into a padded batch.

    Items are ordered by decreasing sequence length (ties keep their incoming
    order) and the sequences are right-padded with zeros to the longest one.
    The batch is created on the device of the input sequences, so it no longer
    requires a GPU; the caller's list is left unmodified.

    Args:
        data: non-empty list of dicts with keys ``'X'`` and ``'y'``.

    Returns:
        ``(x_batch, y_batch)``: a ``long`` tensor of shape
        ``(len(data), longest sequence)`` and a ``long`` tensor of labels,
        both in the sorted order.

    Raises:
        IndexError: if ``data`` is empty.
    """
    # sorted() instead of list.sort() so the caller's list is not reordered.
    ordered = sorted(data, key=lambda item: len(item["X"]), reverse=True)
    sequences = [item["X"] for item in ordered]
    labels = [item["y"] for item in ordered]

    # After the descending sort the first sequence is the longest one.
    longest = sequences[0]
    # Follow the inputs' device; fall back to the historical default ('cuda')
    # when the sequences are not tensors.
    device = longest.device if torch.is_tensor(longest) else torch.device('cuda')

    x_batch = torch.zeros(len(sequences), len(longest), dtype=torch.long, device=device)
    for row, seq in enumerate(sequences):
        x_batch[row, :len(seq)] = seq

    y_batch = torch.tensor(labels, device=device, dtype=torch.long)
    return x_batch, y_batch
|
|
|
|
def train_epoch(data_loader, discriminator, device='cuda', args=None, epoch=1):
    """Run one optimisation pass over ``data_loader``.

    A new Adam optimiser (lr=0.0001) is created on every call, so optimiser
    state does not carry over between epochs.

    Args:
        data_loader: iterable of ``(data, target)`` batches that also supports
            ``len()`` and exposes ``.dataset``.
        discriminator: module providing ``train_custom()`` and returning
            log-probabilities when called on a batch.
        device: device the batches are moved to.
        args: object with a ``log_interval`` attribute (number of batches
            between progress lines). When it is ``None`` or lacks the
            attribute, 10 is used instead of raising ``AttributeError``.
        epoch: epoch number, used only in the progress line.
    """
    log_interval = getattr(args, 'log_interval', 10)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
    discriminator.train_custom()

    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        output = discriminator(data)
        loss = F.nll_loss(output, target)
        # NOTE(review): retain_graph=True is kept as in the original code; it
        # is not obviously required here -- confirm before removing it.
        loss.backward(retain_graph=True)
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Relu Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader.dataset),
                100. * batch_idx / len(data_loader), loss.item()))
|
|
|
|
| def test_epoch(data_loader, discriminator, device='cuda', args=None): |
| discriminator.eval() |
| test_loss = 0 |
| correct = 0 |
| with torch.no_grad(): |
| for data, target in data_loader: |
| data, target = data.to(device), target.to(device) |
| output = discriminator(data) |
| test_loss += F.nll_loss(output, target, reduction='sum').item() |
| pred = output.argmax(dim=1, keepdim=True) |
| correct += pred.eq(target.view_as(pred)).sum().item() |
|
|
| test_loss /= len(data_loader.dataset) |
|
|
| print('\nRelu Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( |
| test_loss, correct, len(data_loader.dataset), |
| 100. * correct / len(data_loader.dataset))) |
|
|
|
|
# Token id placed in front of every encoded training/test sequence.
# NOTE(review): presumably GPT-2's end-of-text id -- confirm against the vocabulary.
_PREFIX_TOKEN_ID = 50256


def _encode_sst_split(split, label_ids, device):
    """Encode one SST split into parallel lists of token tensors and label ids."""
    detokenizer = TreebankWordDetokenizer()
    inputs = []
    labels = []
    for i in range(len(split)):
        example = vars(split[i])
        text = detokenizer.detokenize(example["text"])
        token_ids = [_PREFIX_TOKEN_ID] + tokenizer.encode(text)
        inputs.append(torch.tensor(token_ids, device=device, dtype=torch.long))
        labels.append(label_ids[example["label"]])
    return inputs, labels


def main():
    """Train a ``Discriminator2mean`` on the dataset chosen on the command line.

    Command-line options: ``--batch-size``, ``--log-interval``, ``--epochs``,
    ``--save-model`` and ``--dataset-label`` (``SST``, ``clickbait`` or
    ``toxic``). After every epoch the model is evaluated, a sample prediction
    is printed and, with ``--save-model``, checkpoints are written below
    ``discrim_models/``.
    """
    parser = argparse.ArgumentParser(description='Train a discriminator on top of GPT-2 representations')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='Number of training epochs')
    parser.add_argument('--save-model', action='store_true', help='whether to save the model')
    parser.add_argument('--dataset-label', type=str, default='SST', choices=('SST', 'clickbait', 'toxic'))
    args = parser.parse_args()

    batch_size = args.batch_size
    device = 'cuda'

    if args.dataset_label == 'SST':
        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, _, test_data = datasets.SST.splits(
            text, label, fine_grained=True, train_subtrees=True)
        label_ids = {"positive": 0, "negative": 1, "very positive": 2, "very negative": 3, "neutral": 4}

        # Both splits go through the same helper so that training sequences
        # get the same prefix token as test sequences (previously only the
        # test split was prefixed).
        x, y = _encode_sst_split(train_data, label_ids, device)
        dataset = Dataset(x, y)
        test_x, test_y = _encode_sst_split(test_data, label_ids, device)
        test_dataset = Dataset(test_x, test_y)
        discriminator = Discriminator2mean(class_size=5).to(device)

    elif args.dataset_label == 'clickbait':
        records = []
        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            for line in f:
                try:
                    # SECURITY: eval() executes whatever the file contains;
                    # only use this with trusted data files.
                    records.append(eval(line))
                except Exception:
                    # Best effort: lines that cannot be evaluated are skipped.
                    continue

        x = []
        y = []
        for record in records:
            try:
                seq = tokenizer.encode(record["text"])
                seq = torch.tensor([_PREFIX_TOKEN_ID] + seq, device=device, dtype=torch.long)
                target = record['label']
            except Exception:
                # Best effort: skip records that cannot be encoded or labelled.
                continue
            # Append only once both values exist so x and y stay aligned.
            x.append(seq)
            y.append(target)

        dataset = Dataset(x, y)
        train_size = int(0.9 * len(dataset))
        test_size = len(dataset) - train_size
        dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        discriminator = Discriminator2mean(class_size=2).to(device)

    elif args.dataset_label == 'toxic':
        records = []
        with open("datasets/toxic/toxic_train.txt") as f:
            for line in f:
                # SECURITY: eval() executes whatever the file contains;
                # only use this with trusted data files.
                records.append(eval(line))

        x = []
        y = []
        for record in records:
            try:
                seq = tokenizer.encode(record["text"])
                # Only sequences shorter than 100 tokens are kept.
                if len(seq) >= 100:
                    continue
                seq = torch.tensor([_PREFIX_TOKEN_ID] + seq, device=device, dtype=torch.long)
                # A record counts as positive when its label values sum to more than zero.
                target = int(np.sum(record['label']) > 0)
            except Exception:
                # Best effort: skip records that cannot be encoded or labelled.
                continue
            # Append only once both values exist so x and y stay aligned.
            x.append(seq)
            y.append(target)

        dataset = Dataset(x, y)
        print(dataset)
        print(len(dataset))
        train_size = int(0.9 * len(dataset))
        test_size = len(dataset) - train_size
        dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
        discriminator = Discriminator2mean(class_size=2).to(device)

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size, collate_fn=collate_fn)

    def print_sample_prediction():
        """Print the discriminator's output for one fixed example sentence."""
        seq = tokenizer.encode("This is incredible! I love it, this is the best chicken I have ever had.")
        seq = torch.tensor([seq], device=device, dtype=torch.long)
        print(discriminator(seq))

    for epoch in range(args.epochs):
        train_epoch(discriminator=discriminator, data_loader=data_loader, args=args, device=device, epoch=epoch)
        test_epoch(data_loader=test_loader, discriminator=discriminator, args=args)
        print_sample_prediction()

        if args.save_model:
            # torch.save does not create missing directories.
            os.makedirs("discrim_models", exist_ok=True)
            torch.save(discriminator.state_dict(),
                       "discrim_models/{}_mean_lin_discriminator_{}.pt".format(args.dataset_label, epoch))
            torch.save(discriminator.get_classifier().state_dict(),
                       "discrim_models/{}_classifierhead.pt".format(args.dataset_label))

    print_sample_prediction()


if __name__ == '__main__':
    main()
|
|
|
|