from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoTokenizer

from .configuration_gpt2vision import GPT2VisionConfig
from .vision_encoder import VisionEncoder
from .modeling_gpt2 import GPT2LMHeadModel

# Placeholder token marking where image embeddings are spliced into the prompt.
IMAGE_TOKEN = "<image>"
ANSWER_EOS = "<|endoftext|>"


def resize_token_embeds(model_name="openai-community/gpt2"):
    """Return a GPT-2 tokenizer extended with the <image> special token.

    Only the tokenizer grows here; the matching embedding-matrix resize
    happens in GPT2Vision.__init__ via resize_token_embeddings.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
    return tokenizer


tokenizer = resize_token_embeds()


class MLP(nn.Module):
    """Two-layer projector mapping vision features into the LM embedding space."""

    def __init__(self, in_features: int, hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
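

# Shape sketch for the MLP projector above (assuming a ViT-style encoder that
# emits 768-d patch features): (B, N, 768) -> fc1 -> (B, N, 3072) -> GELU +
# dropout -> fc2 -> (B, N, 768), matching GPT-2's hidden size.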


class GPT2Vision(PreTrainedModel):
    config_class = GPT2VisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_encoder = VisionEncoder()
        # Project 768-d vision features into GPT-2's 768-d embedding space.
        self.mlp = MLP(in_features=768, hidden_features=768 * 4, out_features=768)
        self.language_model = GPT2LMHeadModel(config.gpt2_config)
        # Grow the embedding matrix to cover the added <image> token.
        self.language_model.resize_token_embeddings(len(tokenizer))
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    @property
    def device(self):
        return next(self.language_model.parameters()).device

    def preprocess_inputs(self, batch):
        """Splice projected image embeddings into the token embedding sequence,
        right after the leading <image> token."""
        img_embs = batch["pixel_values"].to(self.device)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)

        tok_embs = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
        # Image positions are always attended to.
        img_attention = torch.ones(
            (img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device
        )
        attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
        return inputs_embeds, attention_mask, input_ids
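
    # Layout sketch: with P projected patch embeddings v_1..v_P and prompt tokens
    # t_1..t_T (t_1 being <image>), preprocess_inputs yields
    #   inputs_embeds  = [emb(t_1), v_1, ..., v_P, emb(t_2), ..., emb(t_T)]
    #   attention_mask = [m_1, 1 (x P), m_2, ..., m_T]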

    def generate(self, question, image, max_new_tokens=30, **kwargs):
        # Encode the image and project the features into GPT-2's embedding space.
        with torch.no_grad():
            img_features = self.vision_encoder(image, device=self.device)
            img_embs = self.mlp(img_features)

        prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
        encoded_input = self.tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True, max_length=720
        )

        # Note: "pixel_values" here already holds the projected image embeddings.
        batch = {
            "pixel_values": img_embs,
            "input_ids": encoded_input.input_ids.to(self.device),
            "attention_mask": encoded_input.attention_mask.to(self.device),
        }

        inputs_embeds, attention_mask, _ = self.preprocess_inputs(batch)

        output_sequences = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )

        output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
        return output
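

# Minimal usage sketch (assumptions, not guaranteed by this file: GPT2VisionConfig
# is constructible with defaults, and VisionEncoder accepts a PIL image; the image
# path below is hypothetical).
if __name__ == "__main__":
    from PIL import Image

    model = GPT2Vision(GPT2VisionConfig()).eval()
    image = Image.open("example.jpg")  # hypothetical path
    print(model.generate("What is in this picture?", image, max_new_tokens=30))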