import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
from lavis.common.registry import registry
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import PositionalEmbedding
# Note: "MutliwayEmbedding" is the spelling used by torchscale itself.
from torchscale.component.multiway_network import MutliwayEmbedding
from transformers import AutoModel, PreTrainedModel
from transformers.utils.generic import ModelOutput

from .configuration_vivqa import ViVQAConfig


class BartPhoExtractor(nn.Module):
    """Extracts contextual token features from VinAI's BartPho-word."""

    def __init__(self):
        super().__init__()
        self.bartpho_word = AutoModel.from_pretrained("vinai/bartpho-word")

    def forward(self, input_ids, attention_mask):
        outputs = self.bartpho_word(input_ids=input_ids, attention_mask=attention_mask)
        # outputs[0] is the last hidden state: (batch, seq_len, 1024).
        return outputs[0]

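# Usage sketch (illustrative, kept as a comment so nothing runs at import
# time). The tokenizer checkpoint matches the extractor above; the question
# string is made up for the example.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-word")
#   batch = tokenizer(["Màu của chiếc xe là gì?"], return_tensors="pt")
#   feats = BartPhoExtractor()(batch["input_ids"], batch["attention_mask"])
#   # feats: (1, seq_len, 1024) -- BartPho-word is BART-large sized.
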
class Blip2EfficientExtractor(nn.Module):
    """Image encoder: global BLIP-2 Q-Former features fused with local
    EfficientNet-B7 grid features."""

    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # BLIP-2 feature extractor from LAVIS; it runs in half precision on
        # GPU, so cast to float32 when falling back to CPU.
        self.model_blip2 = (
            registry.get_model_class(name="blip2_feature_extractor")
            .from_pretrained(model_type="pretrain")
            .to(self.device)
        )
        if self.device == "cpu":
            self.model_blip2 = self.model_blip2.float()
        self.model_blip2.eval()

        # EfficientNet-B7 trained with AdvProp supplies the local features.
        self.model_efficientnet = EfficientNet.from_pretrained(
            "efficientnet-b7", advprop=True
        ).to(self.device)
        self.model_efficientnet.eval()

        # Pool the convolutional grid down to 32 tokens of width 768 so the
        # local features line up with the 32 BLIP-2 query embeddings.
        self.pooling1 = nn.AdaptiveAvgPool2d((1, 32))
        self.pooling2 = nn.AdaptiveAvgPool2d((1, 768))

    def forward(self, images):
        # Global features: BLIP-2 query embeddings, (batch, 32, 768).
        global_features = self.model_blip2.extract_features(
            samples={"image": images}, mode="image"
        ).image_embeds

        # Local features: EfficientNet grid (batch, 2560, h, w) -> (batch, 32, 768).
        local_features = self.model_efficientnet.extract_features(images)
        local_features = self.pooling1(local_features)       # (batch, 2560, 1, 32)
        local_features = local_features.permute(0, 3, 2, 1)  # (batch, 32, 1, 2560)
        local_features = self.pooling2(local_features)       # (batch, 32, 1, 768)
        batch_size = images.shape[0]
        local_features = local_features.reshape(batch_size, local_features.shape[1], -1)

        # Concatenate along the token axis: (batch, 64, 768).
        v = torch.cat([global_features, local_features], dim=1)
        return v

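# Usage sketch (illustrative): images must already be preprocessed the way
# BLIP-2 expects (LAVIS's "blip2" image processor: 224x224, normalized). The
# random batch below only demonstrates shapes, and on GPU the LAVIS model
# runs in fp16, so real inputs may need dtype casting.
#
#   extractor = Blip2EfficientExtractor()
#   images = torch.randn(2, 3, 224, 224).to(extractor.device)
#   with torch.no_grad():
#       v = extractor(images)
#   # v: (2, 64, 768) -- 32 BLIP-2 query tokens + 32 pooled EfficientNet tokens
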
@dataclass
class ViVQAOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


def trunc_normal_(tensor, mean=0., std=1.):
    # timm's trunc_normal_ with the truncation bounds tied to std, matching
    # the initialization used in the original BEiT3 code.
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)

class Pooler(nn.Module):
    def __init__(self, input_features, output_features, norm_layer):
        super().__init__()
        self.norm = norm_layer(input_features)
        self.dense = nn.Linear(input_features, output_features)
        self.activation = nn.Tanh()

    def forward(self, x):
        # Pool the sequence by taking its first token (the CLS position).
        cls_rep = x[:, 0, :]
        cls_rep = self.norm(cls_rep)
        pooled_output = self.dense(cls_rep)
        pooled_output = self.activation(pooled_output)
        return pooled_output

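# Usage sketch (illustrative): pooling reduces (batch, tokens, dim) to
# (batch, dim) using only the first token of the fused sequence.
#
#   pooler = Pooler(768, 768, nn.LayerNorm)
#   pooled = pooler(torch.randn(2, 80, 768))  # -> (2, 768)
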
class ViVQABEiT3(PreTrainedModel):
    def __init__(self, args):
        super().__init__(args)
        assert args.multiway
        assert not args.share_encoder_input_output_embed

        self.text_embed = BartPhoExtractor()

        # The vision branch is kept frozen; only the text branch and the
        # encoder are trained.
        self.vision_embed = Blip2EfficientExtractor()
        for param in self.vision_embed.parameters():
            param.requires_grad = False

        # Project BartPho-word features (1024-d) down to the encoder width.
        self.linear = nn.Linear(1024, 768)

        # 64 visual tokens (32 global + 32 local); the +2 offset follows the
        # original BEiT3 position-embedding setup.
        num_position_embeddings = 64
        embed_positions = MutliwayEmbedding(
            modules=[
                PositionalEmbedding(num_position_embeddings + 2, args.encoder_embed_dim),
                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            ],
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(self, textual_tokens, visual_tokens, text_padding_position):
        # Visual tokens come first; their count marks where the multiway
        # network switches from the vision expert to the language expert.
        x1 = self.vision_embed(visual_tokens)
        multiway_split_position = x1.size(1)

        # BartPho expects an attention mask (1 = attend), the inverse of the
        # padding mask (1 = pad).
        x2 = self.text_embed(textual_tokens, 1 - text_padding_position)
        x2 = self.linear(x2)

        x = torch.cat([x1, x2], dim=1)

        # Vision tokens are never padded; only the text part carries padding.
        encoder_padding_mask = torch.cat(
            [
                torch.zeros(x1.shape[:-1], device=x1.device).bool(),
                text_padding_position,
            ],
            dim=1,
        )

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
        )
        encoder_out["multiway_split_position"] = multiway_split_position
        return encoder_out

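# Layout of the fused sequence fed to the multiway encoder:
#
#   [ v_0 ... v_31 | v_32 ... v_63 | t_0 ... t_{L-1} ]
#     BLIP-2 global  EfficientNet    projected BartPho text
#    \_________ vision expert _____/ \__ language expert __/
#                                    ^ multiway_split_position = 64
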
class BEiT3Wrapper(PreTrainedModel):
    def __init__(self, args, **kwargs):
        super().__init__(args)
        self.beit3 = ViVQABEiT3(args)

    def fix_init_weight(self):
        # Depth-dependent rescaling of the residual-branch projections, as in
        # the original BEiT3 code (which iterated over self.blocks). Adapted
        # here to the torchscale encoder, whose blocks live under
        # self.beit3.encoder.layers and whose projection weights may be
        # wrapped in multiway (A/B) submodules.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.beit3.encoder.layers):
            for name, param in layer.named_parameters():
                if name.endswith("weight") and ("out_proj" in name or "fc2" in name):
                    rescale(param.data, layer_id + 1)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'pos_embed',
            'cls_token',
            'beit3.encoder.embed_positions.A.weight',
            'beit3.vision_embed.cls_token',
            'logit_scale',
        }

    def _init_weights(self, m):
        # Truncated-normal init for linear layers; standard LayerNorm init.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

class BEiT3ForVietnameseVisualQuestionAnswering(BEiT3Wrapper):
    config_class = ViVQAConfig

    def __init__(self, args, num_classes=353, **kwargs):
        super().__init__(args=args)
        embed_dim = args.encoder_embed_dim
        self.pooler = Pooler(
            input_features=embed_dim,
            output_features=embed_dim,
            norm_layer=nn.LayerNorm,
        )
        self.pooler.apply(self._init_weights)
        # Classification head over the fixed answer vocabulary.
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.LayerNorm(embed_dim * 2),
            nn.GELU(),
            nn.Linear(embed_dim * 2, num_classes),
        )
        self.head.apply(self._init_weights)

    def forward(self, image, question, padding_mask, labels=None, **kwargs):
        outputs = self.beit3(
            textual_tokens=question,
            visual_tokens=image,
            text_padding_position=padding_mask,
        )
        x = outputs["encoder_out"]
        cls_rep = self.pooler(x)
        logits = self.head(cls_rep)

        # VQA is treated as classification over the answer set.
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return ViVQAOutput(
            loss=loss,
            logits=logits,
        )
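

# End-to-end usage sketch (illustrative, kept as a comment). It assumes
# ViVQAConfig exposes the torchscale encoder fields this model asserts on
# (multiway=True, encoder_embed_dim=768, max_source_positions,
# share_encoder_input_output_embed=False); the shapes and the question are
# made up for the example.
#
#   from transformers import AutoTokenizer
#   config = ViVQAConfig()
#   model = BEiT3ForVietnameseVisualQuestionAnswering(config).eval()
#   tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-word")
#   batch = tokenizer(["Màu của chiếc xe là gì?"], return_tensors="pt")
#   padding_mask = 1 - batch["attention_mask"]  # 1 marks padded positions
#   images = torch.randn(1, 3, 224, 224)        # stand-in for a real image
#   with torch.no_grad():
#       out = model(image=images, question=batch["input_ids"], padding_mask=padding_mask)
#   # out.logits: (1, 353) scores over the fixed answer set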