import datasets
import pandas as pd
from PIL import Image
import multiprocessing as mp
from sklearn.model_selection import train_test_split

import torch
from torchvision import transforms
from torch.utils.data import Dataset

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor
from transformers import AutoTokenizer, default_data_collator

import os

# Disable Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"

# TPU support via torch_xla; `dev` is not used below (the GPU/CPU check that
# follows picks the device), so this import only matters when running on TPU.
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
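
# Note: the Seq2SeqTrainer used below handles device placement itself, so the
# `device` variable here is informational rather than passed to the model.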


class config:
    ENCODER = "google/vit-base-patch16-224"
    DECODER = "gpt2"
    TRAIN_BATCH_SIZE = 64
    VAL_BATCH_SIZE = 64
    VAL_EPOCHS = 1
    LR = 5e-5
    SEED = 42
    MAX_LEN = 128
    SUMMARY_LEN = 20
    WEIGHT_DECAY = 0.01
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)
    TRAIN_PCT = 0.95
    NUM_WORKERS = mp.cpu_count()
    EPOCHS = 1
    IMG_SIZE = (224, 224)
    LABEL_MASK = -100
    TOP_K = 10
    TOP_P = 0.95
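
# Hyperparameter collection for the run. Only ENCODER, DECODER, the batch sizes
# and EPOCHS are referenced below; the remaining fields (MEAN/STD, TRAIN_PCT,
# LR, etc.) are defined but unused as the script is written.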


def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    # Wrap every caption in explicit BOS/EOS markers (GPT-2 adds none by default).
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs


# Patch the concrete class: AutoTokenizer is only a factory, so patching it has
# no effect on the GPT2TokenizerFast instance created below. (With the fast,
# Rust-backed tokenizer, special tokens are normally handled by its post-processor,
# so loading the tokenizer with use_fast=False may be needed for this override
# to actually take effect.)
from transformers import GPT2TokenizerFast

GPT2TokenizerFast.build_inputs_with_special_tokens = build_inputs_with_special_tokens
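
# Intended effect once the override is active:
#   tokenizer("some caption").input_ids  ->  [bos_id] + token_ids + [eos_id]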


rouge = datasets.load_metric("rouge")


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }
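
# The Trainer calls compute_metrics at each evaluation step; the -100 label
# positions are mapped back to the pad id first so that ROUGE-2 compares plain
# decoded strings.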


feature_extractor = ViTFeatureExtractor.from_pretrained(config.ENCODER)
tokenizer = AutoTokenizer.from_pretrained(config.DECODER)
# GPT-2 has no pad token, so reuse its unk token (<|endoftext|>) for padding.
tokenizer.pad_token = tokenizer.unk_token

# Basic tensor conversion plus 0.5/0.5 normalisation (pixels mapped to [-1, 1]);
# named img_transforms so the torchvision.transforms module is not shadowed.
img_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ]
)
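
# Note: ToTensor scales pixels to [0, 1] and the Normalize above shifts them to
# [-1, 1]; ImgDataset undoes that shift before handing the image to the ViT
# feature extractor, which applies its own resizing and normalisation.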


class ImgDataset(torch.utils.data.Dataset):
    def __init__(self, df, root_dir, tokenizer, feature_extractor, transform):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = 128

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        caption = self.df.tags.iloc[idx]
        image = self.df.image_id.iloc[idx] + ".jpg"
        folder_name = str(self.df.folder_name.iloc[idx])
        img_path = os.path.join(self.root_dir, folder_name, image)
        img = Image.open(img_path).convert("RGB")

        img = self.transform(img)

        # The transform normalises to [-1, 1]; map back to [0, 1] before the
        # feature extractor applies its own normalisation.
        if img.min() < 0.0:
            img = (img + 1.0) / 2.0

        pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values
        captions = self.tokenizer(caption,
                                  padding='max_length',
                                  max_length=self.max_length,
                                  truncation=True).input_ids
        # Mask pad positions with -100 so they are ignored by the loss.
        captions = [token if token != self.tokenizer.pad_token_id else -100 for token in captions]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(captions)}
        return encoding
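
# Rough sketch of a single item (assuming the 224x224 ViT inputs and the
# max_length of 128 used above):
#   {"pixel_values": FloatTensor of shape (3, 224, 224),
#    "labels": LongTensor of shape (128,) with pad positions set to -100}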


# One pass over each caption CSV; the model is saved after every folder and
# reloaded at the start of the next one, so training continues across folders.
for j in range(1, 179 + 1):
    df = pd.read_csv(rf"posts/posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder_{j}.csv")
    train_df, val_df = train_test_split(df, test_size=0.02)
    print(df.head(3))

    train_dataset = ImgDataset(
        train_df,
        root_dir="dump_small",
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        transform=img_transforms,
    )

    val_dataset = ImgDataset(
        val_df,
        root_dir="dump_small",
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        transform=img_transforms,
    )

    # Reload the previously saved model if it exists; otherwise build a fresh
    # encoder-decoder from the pretrained ViT and GPT-2 checkpoints.
    if os.path.exists('VIT_large_gpt2_model'):
        model = VisionEncoderDecoderModel.from_pretrained('VIT_large_gpt2_model')
    else:
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER)

    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    # GPT-2 has no sep token, so use its eos token to terminate generation.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.max_length = 128
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 2
    model.config.length_penalty = 2.0
    model.config.num_beams = 2
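
    # These generation settings (beam search with 2 beams, max length 128, no
    # repeated 2-grams, length penalty 2.0) are what predict_with_generate uses
    # when the Trainer runs evaluation.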

    training_args = Seq2SeqTrainingArguments(
        output_dir='VIT_large_gpt2',
        per_device_train_batch_size=config.TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=config.VAL_BATCH_SIZE,
        predict_with_generate=True,
        evaluation_strategy="steps",
        do_train=True,
        do_eval=True,
        logging_steps=1000,
        save_steps=1000,
        warmup_steps=200,
        learning_rate=5e-5 - j * 2.2e-7,
        num_train_epochs=config.EPOCHS,
        overwrite_output_dir=True,
        save_total_limit=3,
    )
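
    # The learning rate is stepped down linearly with the folder index: roughly
    # 5.0e-5 for the first folder down to about 1.1e-5 for folder 179.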

    """import transformers.trainer
    from transformers.trainer import SequentialSampler


    def sampler_monkey_patch(dataset, generator):
        return SequentialSampler(dataset)


    transformers.trainer.RandomSampler = sampler_monkey_patch"""
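    # Left disabled: if applied, the patch above would swap the Trainer's
    # RandomSampler for a SequentialSampler so batches follow dataset order
    # instead of being shuffled.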

    trainer = Seq2SeqTrainer(
        tokenizer=feature_extractor,  # saved alongside the model checkpoints
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=default_data_collator,
    )
    # Resume from the saved directory when possible; fall back to a fresh run
    # if it is missing or not a valid checkpoint.
    try:
        trainer.train(resume_from_checkpoint='VIT_large_gpt2_model')
    except Exception:
        trainer.train()
    trainer.save_model('VIT_large_gpt2_model')
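
# A minimal inference sketch (assumptions: the saved model/feature extractor in
# 'VIT_large_gpt2_model' and some local image 'example.jpg'); left commented out
# so the training script's behaviour is unchanged:
#
#   model = VisionEncoderDecoderModel.from_pretrained('VIT_large_gpt2_model').to(device)
#   img = Image.open('example.jpg').convert('RGB')
#   pixel_values = feature_extractor(img, return_tensors='pt').pixel_values.to(device)
#   generated_ids = model.generate(pixel_values, max_length=128, num_beams=2)
#   print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])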