| import torch |
| torch.manual_seed(1024) |
|
|
| import torch.nn as nn |
| from transformers import PreTrainedModel |
|
|
| from .configuration_hformer import HformerConfig |
| from .qformer_src import BertConfig, BertLMHeadModel |
|
|
| from transformers import BertTokenizerFast as BertTokenizer |
|
|
| from .configuration_projector import ProjectorConfig |
| from .modeling_projector import ProjectorModel |
| import torch.nn.functional as F |
| from transformers.activations import ACT2FN |
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """Thin subclass of ``nn.LayerNorm`` with an identical forward pass.

    Kept as a distinct class so framework hooks (e.g. mixed-precision or
    freezing logic) can target these normalization layers separately from
    any other ``nn.LayerNorm`` in the model.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate directly to the parent implementation; no extra logic.
        return super().forward(x)
|
|
class HformerModel(PreTrainedModel):
    """Hybrid vision-to-LLM projector combining a Q-Former and an MLP.

    Two branches run over the incoming visual features:

    * a BERT-based Q-Former whose learned query tokens cross-attend to the
      (layer-normalized) vision features, projected to LLM width by
      ``llm_proj``;
    * an MLP projector (``connector``) applied to the raw vision features.

    In ``forward`` the pooled Q-Former feature is broadcast-added to every
    projector token, then refined by a small residual bottleneck FFN.
    Gradient checkpointing is intentionally unsupported.
    """

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build Q-Former, MLP projector, and fusion FFN from ``config``.

        Expected config fields: ``visual_hidden_size``, ``num_query_token``,
        ``bert`` (BERT model name/path), ``llm_hidden_size``,
        ``cross_attention_freq``, ``qformer_pth`` (optional checkpoint path),
        and ``bias``.
        """
        super().__init__(config)
        self.gradient_checkpointing = False
        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Q-Former: a BERT LM-head model extended with cross-attention to
        # the vision features every `cross_attention_freq` layers.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        # Depth fixed at 12 regardless of the pretrained config.
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(
            bert, config=encoder_config
        )

        # Learned query tokens, shape (1, num_query_token, hidden); expanded
        # per-batch in forward(). Initialized like BERT embeddings.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        # Projects Q-Former hidden states into the LLM embedding space.
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # NOTE(review): ln_llava is registered but never used in forward();
        # kept so existing checkpoints/state_dicts keep loading — confirm
        # whether it is referenced elsewhere before removing.
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # Register the [DEC] BOS token and grow the embedding table to match,
        # mirroring the BLIP-2-style Q-Former setup.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # SECURITY(review): torch.load unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Load Qformer from {qformer_pth}')
            # strict=False: the checkpoint covers only the Q-Former weights.
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # MLP branch: maps raw vision features directly to LLM width.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2)
        self.connector = ProjectorModel(projector_config)

        # Residual refinement FFN with a 4x bottleneck, applied after fusing
        # the two branches.
        self.ffn = nn.Sequential(
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False),
        )

    def enable_input_require_grads(self):
        """Force submodule outputs to require grad.

        Needed so gradients still flow into this projector when the vision
        backbone ahead of it is frozen (e.g. with checkpointing upstream).
        """
        def make_inputs_require_grad(module, input, output):
            if isinstance(output, tuple):
                # assumes the first two tuple entries are tensors — TODO
                # confirm against the wrapped modules' return signatures.
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        # Gradient checkpointing intentionally unsupported; keep as no-op so
        # the transformers machinery can call it safely.
        pass

    def forward(self, x_):
        """Fuse Q-Former and MLP projections of visual features ``x_``.

        Args:
            x_: visual features; presumably shaped (batch, num_patches,
                visual_hidden_size) — confirm against the caller.

        Returns:
            Tensor of LLM-width embeddings, one per projector token.
        """
        if self.gradient_checkpointing and self.training:
            print('Not support gradient checkpointing')
        x = self.ln_vision(x_)
        # Expand the shared query tokens across the batch dimension.
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        # (batch, num_query_token, llm_hidden_size)
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # MLP branch runs on the raw (un-normalized) features.
        mlp_feat = self.connector(x_)

        # Broadcast the pooled query feature over every projector token,
        # then apply the residual bottleneck FFN.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]
        return int_feat + self.ffn(int_feat)
|
|
|
|