import os import logging import torch import torch.nn as nn import lightning import torchmetrics import time from pathlib import Path as path from torch.nn import functional as F from torch.utils.data import Dataset, DataLoader from lightning import Fabric from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import CSVLogger from transformers import ( AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, ) # from config import CustomConfig class ClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) classifier_dropout = 0.1 # classifier_dropout = ( # config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob # ) self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) def forward(self, features, **kwargs): x = features[:, 0, :] # take token (equiv. to [CLS]) x = self.dropout(x) x = self.dense(x) x = torch.tanh(x) x = self.dropout(x) x = self.out_proj(x) return x class CustomModel(nn.Module): def __init__(self, model_name, pretrained_model_fold='./pretrained_model', share_encoder=False, ): super().__init__() self.model_name = model_name self.pretrained_model_fold = pretrained_model_fold self.share_encoder = share_encoder self.model_config = AutoConfig.from_pretrained(model_name, num_labels=2, cache_dir=pretrained_model_fold) self.encoder = AutoModel.from_config(self.model_config) if share_encoder: self.decoder_list = nn.ModuleList([ClassificationHead(self.model_config)for _ in range(3)]) else: self.decoder = ClassificationHead(self.model_config) # self.model = AutoModelForSequenceClassification.from_pretrained(config.model_name, num_labels=2) def get_pretrained_encoder(self): logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) # logging.getLogger("transformers").setLevel(logging.ERROR) cache_dir = self.pretrained_model_fold path(cache_dir).mkdir(parents=True, exist_ok=True) self.encoder = AutoModel.from_pretrained(self.model_name, cache_dir=cache_dir) def freeze_encoder(self): for param in self.encoder.parameters(): param.requires_grad = False def forward(self, batch_x): feature = self.encoder(**batch_x) feature = feature['last_hidden_state'] # feature = feature[0] if self.share_encoder: logits_list = [decoder(feature)for decoder in self.decoder_list] # cls(3), bsz, 2 prob_list = [F.softmax(logits, dim=-1)for logits in logits_list] # cls, bsz, 2 return torch.stack(prob_list, dim=0) # cls, bsz, 2 else: logits = self.decoder(feature) # bsz, 2 prob = F.softmax(logits, dim=-1) # bsz, 2 return prob def predict(self, batch_x): output = self(batch_x) # cls, bsz, 2 or bsz, 2 preds = torch.argmax(output, dim=-1) # cls, bsz or bsz return preds class Modelv2(lightning.LightningModule): def __init__(self, model_name='bert-base-uncased', pretrained_model_fold='./pretrained_model', share_encoder=False, rdrop=None, early_dropout=None, optimizer=torch.optim.AdamW, lr=5e-5, criterion=nn.CrossEntropyLoss(), ): super().__init__() self.model_name = model_name self.pretrained_model_fold = pretrained_model_fold self.share_encoder = share_encoder self.rdrop = rdrop self.early_dropout = early_dropout self.model_config = AutoConfig.from_pretrained(model_name, num_labels=2, cache_dir=pretrained_model_fold) self.encoder = AutoModel.from_config(self.model_config) if share_encoder: self.decoder_list = nn.ModuleList([ClassificationHead(self.model_config)for _ in range(3)]) else: self.decoder = ClassificationHead(self.model_config) self.optimizer = optimizer self.lr = lr self.criterion = criterion self.metric_name_list = ['accuracy', 'precision', 'recall', 'f1'] if self.share_encoder: self.train_metric_list = [ [ torchmetrics.Accuracy('binary'), torchmetrics.Precision('binary'), torchmetrics.Recall('binary'), torchmetrics.F1Score('binary') ] for _ in range(3) ] self.val_metric_list = [ [ torchmetrics.Accuracy('binary'), torchmetrics.Precision('binary'), torchmetrics.Recall('binary'), torchmetrics.F1Score('binary') ] for _ in range(3) ] self.test_metric_list = [ [ torchmetrics.Accuracy('binary'), torchmetrics.Precision('binary'), torchmetrics.Recall('binary'), torchmetrics.F1Score('binary') ] for _ in range(3) ] else: self.train_metric_list = [ torchmetrics.Accuracy('binary'), torchmetrics.Precision('binary'), torchmetrics.Recall('binary'), torchmetrics.F1Score('binary') ] self.val_metric_list = [ torchmetrics.Accuracy('binary'), torchmetrics.Precision('binary'), torchmetrics.Recall('binary'), torchmetrics.F1Score('binary') ] self.test_metric_list = [ torchmetrics.Accuracy('binary'), torchmetrics.Precision('binary'), torchmetrics.Recall('binary'), torchmetrics.F1Score('binary') ] def recurse_moduleList(lst): lst = [recurse_moduleList(p) if type(p) == list else p for p in lst] return nn.ModuleList(lst) self.train_metric_list = recurse_moduleList(self.train_metric_list) self.val_metric_list = recurse_moduleList(self.val_metric_list) self.test_metric_list = recurse_moduleList(self.test_metric_list) def get_pretrained_encoder(self): logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) # logging.getLogger("transformers").setLevel(logging.ERROR) cache_dir = self.pretrained_model_fold path(cache_dir).mkdir(parents=True, exist_ok=True) self.encoder = AutoModel.from_pretrained(self.model_name, cache_dir=cache_dir) def freeze_encoder(self): for param in self.encoder.parameters(): param.requires_grad = False def forward(self, batch_x): feature = self.encoder(**batch_x) feature = feature['last_hidden_state'] # feature = feature[0] if self.share_encoder: logits_list = [decoder(feature)for decoder in self.decoder_list] # cls(3), bsz, 2 return torch.stack(logits_list, dim=0) # prob_list = [F.softmax(logits, dim=-1)for logits in logits_list] # cls, bsz, 2 # return torch.stack(prob_list, dim=0) # cls, bsz, 2 else: logits = self.decoder(feature) # bsz, 2 return logits # prob = F.softmax(logits, dim=-1) # bsz, 2 # return prob def predict(self, batch_x): output = self(batch_x) # cls, bsz, 2 or bsz, 2 preds = torch.argmax(output, dim=-1) # cls, bsz or bsz return preds def predict_prob(self, batch_x): output = self(batch_x) probs = torch.softmax(output, dim=-1) probs = probs[..., 1] return probs def one_step(self, batch, stage): xs, ys = batch if self.rdrop == None: logits = self(xs) loss = self.criterion(logits.view(-1,2), ys.view(-1)) else: logits1 = self(xs) logits2 = self(xs) logits = logits1 ce_loss1 = self.criterion(logits1.view(-1,2), ys.view(-1)) ce_loss2 = self.criterion(logits2.view(-1,2), ys.view(-1)) kl_loss1 = F.kl_div(F.log_softmax(logits1, dim=-1), F.softmax(logits2, dim=-1), reduction='mean') kl_loss2 = F.kl_div(F.log_softmax(logits2, dim=-1), F.softmax(logits1, dim=-1), reduction='mean') loss = (ce_loss1+ce_loss2)/2 + self.rdrop * (kl_loss1+kl_loss2)/2 self.log(f'{stage}_loss', loss) with torch.no_grad(): preds = torch.argmax(logits, -1) metric_list = getattr(self, f'{stage}_metric_list') if self.share_encoder: macro_f1 = 0 for p in range(3): for metric_name, metric in zip(self.metric_name_list, metric_list[p]): metric(preds[p], ys[p]) self.log(f'{stage}_{metric_name}_{p}', metric, on_epoch=True, on_step=False) macro_f1 += metric_list[p][-1].compute() macro_f1 /= 3 self.log(f'{stage}_macro_f1', macro_f1, on_epoch=True, on_step=False) else: for metric_name, metric in zip(self.metric_name_list, metric_list): metric(preds, ys) self.log(f'{stage}_{metric_name}', metric, on_epoch=True, on_step=False) self.log(f'{stage}_macro_f1', metric_list[-1], on_epoch=True, on_step=False) return loss def on_train_epoch_start(self) -> None: # print(self.current_epoch) if self.early_dropout == None: return if self.current_epoch == self.early_dropout: for name, module in self.named_modules(): if isinstance(module, nn.Dropout): module.p = 0 def training_step(self, batch, batch_idx): return self.one_step(batch, 'train') def validation_step(self, batch, batch_idx): self.one_step(batch, 'val') def test_step(self, batch, batch_idx): self.one_step(batch, 'test') def configure_optimizers(self): params_list = [{'params':self.encoder.parameters()}] if self.share_encoder: params_list.append({'params': self.decoder_list.parameters()}) else: params_list.append({'params': self.decoder.parameters()}) return self.optimizer(params_list, self.lr) if __name__ == '__main__': class SampleDataset(Dataset): def __init__(self, model_name, pretrained_model_fold='./pretrained_model', share_encoder=False) -> None: super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=pretrained_model_fold) self.data = [ 'a sample sentence', 'two sample sentences', 'three sample sentences', 'four sample sentences '*3, # '谢谢关注', ]*10 self.share_encoder = share_encoder def __len__(self): return len(self.data) def __getitem__(self, index): if self.share_encoder: return self.data[index], (index%2,)*3 else: return self.data[index], index%2 def collate_fn(self, batch_data): xs, ys = zip(*batch_data) xs = self.tokenizer(xs, padding=True, truncation=True, return_tensors='pt') ys = torch.tensor(ys) if self.share_encoder: ys = ys.reshape((3,-1)) return xs, ys sample_model_name = 'bert-base-uncased' # sample_model_name = 'distilBert-base' sample_pretrained_model_fold = './pretrained_model' sample_share_encoder = True devices = [4] def sample_model_forward(): print('--- start testing') start_time = time.time() cur_time = time.time() sample_data = SampleDataset(sample_model_name, sample_pretrained_model_fold, sample_share_encoder) sample_data = DataLoader(sample_data, batch_size=5, collate_fn=sample_data.collate_fn) print(f'prepare data cost {time.time()-cur_time:.2f}s') cur_time = time.time() sample_model = Modelv2( sample_model_name, sample_pretrained_model_fold, share_encoder=sample_share_encoder, ) print(f'prepare model cost {time.time()-cur_time:.2f}s') cur_time = time.time() sample_model.get_pretrained_encoder() # sample_model.freeze_encoder() print(f'load model cost {time.time()-cur_time:.2f}s') cur_time = time.time() fab = Fabric(accelerator='cuda',devices=devices,precision='16-mixed') fab.launch() sample_model_fab = fab.setup_module(sample_model) sample_data_fab = fab.setup_dataloaders(sample_data) fab.barrier() print(f'prepare fabric cost {time.time()-cur_time:.2f}s') cur_time = time.time() for sample_x, sample_y in sample_data_fab: print('x') # print(sample_x) print(sample_x['input_ids'].shape) print('y') # print(sample_y) print(sample_y.shape) sample_output = sample_model_fab(sample_x) print('output') # print(sample_output) print(sample_output.shape) break print(f'deal one item cost {time.time()-cur_time:.2f}s') print(f'total cost {time.time()-start_time:.2f}s') def sample_train_test(): sample_data = SampleDataset(sample_model_name, sample_pretrained_model_fold, sample_share_encoder) sample_data = DataLoader(sample_data, batch_size=5, collate_fn=sample_data.collate_fn) sample_model = Modelv2( sample_model_name, sample_pretrained_model_fold, share_encoder=sample_share_encoder, ) sample_model.get_pretrained_encoder() sample_callbacks = [ModelCheckpoint( dirpath='logs/sample_ckpt/', filename='{epoch}-{val_macro_f1:.2f}', monitor='val_macro_f1', save_top_k=3, mode='max', )] sample_logger = CSVLogger(save_dir='logs', name='sample-log', version=10) sample_logger.log_hyperparams({'lr': 10, 'version': 'sample'}) sample_trainer = lightning.Trainer( max_epochs=5, callbacks=sample_callbacks, accelerator='gpu', devices=devices, logger=sample_logger, log_every_n_steps=10, # deterministic=True, precision='16-mixed', # strategy='deepspeed_stage_2' ) sample_trainer.fit( model=sample_model, train_dataloaders=sample_data, val_dataloaders=sample_data, ) # sample_ckpt = torch.load('./logs/sample_ckpt/epoch=0-step=8-v1.ckpt') # sample_lightning_model.load_state_dict(sample_ckpt['state_dict']) sample_trainer.test( model=sample_model, dataloaders=sample_data, ckpt_path='best' ) def sample_load_ckpt(): sample_data = SampleDataset(sample_model_name, sample_pretrained_model_fold, sample_share_encoder) sample_data = DataLoader(sample_data, batch_size=5, collate_fn=sample_data.collate_fn) sample_model = Modelv2( sample_model_name, sample_pretrained_model_fold, share_encoder=sample_share_encoder ) sample_ckpt_file = './logs/sample_ckpt/epoch=0-val_macro_f1=1.00.ckpt' # sample_model: lightning.LightningModule sample_model.load_from_checkpoint(sample_ckpt_file) fab = Fabric(accelerator='cuda',devices=devices,precision='16') fab.launch() sample_model_fab = fab.setup_module(sample_model) sample_data_fab = fab.setup_dataloaders(sample_data) fab.barrier() for sample_x, sample_y in sample_data_fab: print('x') # print(sample_x) print(sample_x['input_ids'].shape) print('y') # print(sample_y) print(sample_y.shape) sample_output = sample_model_fab(sample_x) print('output') # print(sample_output) print(sample_output.shape) break sample_model_forward() print('-'*20) sample_train_test() print('-'*20) sample_load_ckpt() print('-'*20) pass