| """PyTorch MarkupDM model.""" |
|
|
| import contextlib |
| import math |
| import os |
| from typing import Any |
|
|
| import rff.layers |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import ( |
| AutoModel, |
| AutoModelForCausalLM, |
| GenerationMixin, |
| PreTrainedModel, |
| ) |
| from transformers.loss.loss_utils import LOSS_MAPPING |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.utils import logging |
|
|
| from .configuration_markupdm import MarkupDMConfig |
| from .loss_utils import WeightedCausalLMLoss |
|
|
logger = logging.get_logger(__name__)

# Register the custom loss in the transformers loss mapping so it can be selected by name.
LOSS_MAPPING["WeightedCausalLMLoss"] = WeightedCausalLMLoss
|
|
class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin):
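    """Causal language model over interleaved text tokens and VQ image tokens (MarkupDM)."""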

    config: MarkupDMConfig
    config_class = MarkupDMConfig

    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(
        self,
        config: MarkupDMConfig,
        text_model: PreTrainedModel,
        vision_model: PreTrainedModel,
    ) -> None:
        if not isinstance(config, self.config_class):
            raise ValueError(f"Config: {config} has to be of type {self.config_class}")
| logger.info(f"MarkupDM config: {config}") |
| super().__init__(config) |
|
|
| self.text_model = text_model.train() |
| self.vision_model = vision_model.eval().requires_grad_(False) |
|
|
| if self.text_model.config.to_dict() != self.config.text_model.to_dict(): |
| logger.warning( |
| f"Config of the text model: {self.text_model.__class__} is" |
| f"overwritten by shared text config: {self.config.text_model}" |
| ) |
| if self.vision_model.config.to_dict() != self.config.vision_model.to_dict(): |
| logger.warning( |
| f"Config of the vision model: {self.vision_model.__class__} is" |
| f"overwritten by shared vision config: {self.config.vision_model}" |
| ) |
|
|
| |
| |
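
        # Keep the sub-model configs in sync with the shared MarkupDM config.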
        self.text_model.config = self.config.text_model
        self.vision_model.config = self.config.vision_model
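
        # Resize the text token embeddings if the shared vocabulary is larger.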
        base_size = self.text_model.config.vocab_size
        if base_size < self.config.vocab_size:
            self.text_model.resize_token_embeddings(self.config.vocab_size)
            new_size = self.text_model.get_input_embeddings().num_embeddings
            logger.info(f"Resize embedding layer from {base_size} to {new_size} tokens")

        d_text = self.text_model.config.hidden_size
        assert self.vision_model.config.model_type == "vqmodel"
        d_vision = self.vision_model.model.embed_dim
        image_pos_size = self.config.image_pos_size
        sigma = self.config.image_pos_sigma
        m = math.ceil(image_pos_size / 2)
        self.image_vocab_size = self.vision_model.model.n_embed
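
        # Fourier positional encoding for image tokens, a projection into the text
        # embedding space, and a head predicting logits over the image codebook.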
        self.proj_vpos = rff.layers.PositionalEncoding(sigma, m)
        self.proj_vt = nn.Linear(d_vision + image_pos_size, d_text)
        self.vis_head = nn.Linear(d_text, self.image_vocab_size)
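
        # Number of VQ tokens per image, given by the encoder's downsampling factor.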
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.config.image_size // scale_factor
        self.num_image_tokens = latent_size**2
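
        # Initialize weights and apply final processing.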
        self.post_init()
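
        # Optionally keep the text token embeddings frozen during training.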
        if config.freeze_text_embeddings:
            self.text_model.get_input_embeddings().requires_grad_(False)

    def tie_weights(self) -> None:
        self.text_model.tie_weights()
|
|
    @classmethod
    def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
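        """Build the text and vision sub-models from the config, then load weights as usual."""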
| assert "config" in kwargs, "Config must be provided" |
| config = kwargs["config"] |
| dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None)) |
|
|
| |
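
        # Instantiate the text decoder from its config (weights are loaded later).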
        text_model = AutoModelForCausalLM.from_config(
            config.text_model,
            dtype=dtype,
            attn_implementation=config._attn_implementation,
        )
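
        # Instantiate the VQ vision model; silence anything it prints to stdout.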
        with contextlib.redirect_stdout(open(os.devnull, "w")):
            vision_model = AutoModel.from_config(
                config.vision_model,
                trust_remote_code=True,
                dtype=dtype,
            )

        return super().from_pretrained(
            *args,
            **kwargs,
            text_model=text_model,
            vision_model=vision_model,
        )
|
|
    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor]] | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        num_items_in_batch: int | None = None,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast:
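        """Embed interleaved text/image tokens, run the text decoder, and return joint logits (and loss)."""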
        for key, value in kwargs.items():
            if value is not None:
                raise ValueError(f"Unknown argument: {key}={value}")

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if image_mask is None:
            image_mask = input_ids >= self.config.vocab_size
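
        # Embed text tokens and image tokens into a single input sequence.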
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                input_ids,
                image_mask=image_mask,
                image_pos_ids=image_pos_ids,
            )
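
        # Run the text decoder; hidden states are always requested for the vision head.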
        fwd_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "output_hidden_states": True,
            "output_attentions": output_attentions,
        }
        if self.config.text_model.model_type == "starcoder2":
            fwd_kwargs["cache_position"] = cache_position
        outputs = self.text_model(**fwd_kwargs)
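
        # Logits over the text vocabulary.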
        text_logits = outputs.logits[:, :, : self.config.vocab_size]
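
        # Logits over the image codebook, predicted from the last hidden states.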
        last_hidden_states = outputs.hidden_states[-1]
        vision_logits = self.vis_head(last_hidden_states)

        if labels is not None:
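            # The model predicts the next token, so shift the image mask by one and
            # disable the logits of the modality that is not the target at each position.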
            shift_mask = F.pad(image_mask[:, 1:], (0, 1), value=False)
            text_logits[shift_mask] = -float("inf")
            vision_logits[~shift_mask] = -float("inf")
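
        # A single distribution over the concatenated text and image vocabularies.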
        logits = torch.cat([text_logits, vision_logits], dim=-1)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                image_vocab_size=self.image_vocab_size,
                image_loss_weight=self.config.image_loss_weight,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
|
|
    def embed_tokens(
        self,
        input_ids: torch.Tensor,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
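        """Embed a sequence that mixes text token ids with offset image token ids."""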
        if image_mask is None:
            return self.text_embed(input_ids)
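
        # Allocate the combined embedding tensor.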
        size = input_ids.size() + (self.text_model.config.hidden_size,)
        inputs_embeds = torch.zeros(size, device=self.device, dtype=self.dtype)
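
        # Text tokens: look up the text embedding table.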
        text_embeds = self.text_embed(input_ids[~image_mask])
        inputs_embeds[~image_mask] = text_embeds
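
        # Image tokens: ids are offset by the text vocabulary size and index the VQ codebook.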
        image_embeds = self.vis_embed(input_ids[image_mask] - self.config.vocab_size)
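
        # Append Fourier features encoding each token's position within its image.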
        assert image_pos_ids is not None
        image_pos = image_pos_ids / self.num_image_tokens
        image_pos = self.proj_vpos(image_pos.unsqueeze(-1)).to(image_embeds)
        image_pos = image_pos[image_mask][:, : self.config.image_pos_size]
        image_embeds = torch.cat([image_embeds, image_pos], dim=-1)
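
        # Project the image features into the text embedding space.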
        image_embeds = self.proj_vt(image_embeds)
        inputs_embeds[image_mask] = image_embeds

        return inputs_embeds

    def text_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.text_model.get_input_embeddings()(input_ids)

    def vis_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.vision_model.model.quantize.embedding(input_ids)
|
|
    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **model_kwargs: Any
    ) -> dict:
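        # Delegate the standard preparation (cache handling, input slicing) to the text model.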
        default_prepare_inputs = self.text_model.prepare_inputs_for_generation
        inputs = default_prepare_inputs(input_ids, **model_kwargs)
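
        # Recompute the within-image position id of every image token over the full
        # sequence, then keep only the positions matching the prepared input length.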
        base_ids = torch.arange(self.num_image_tokens, device=self.device)
        image_pos_ids = torch.zeros_like(input_ids)
        image_mask_all = input_ids >= self.config.vocab_size
        for i_batch, image_mask in enumerate(image_mask_all):
            N = int(image_mask.sum())
            pos_ids = base_ids.repeat(N // self.num_image_tokens + 1)
            image_pos_ids[i_batch, image_mask] = pos_ids[:N]
        length = inputs["input_ids"].size(1)
        inputs["image_pos_ids"] = image_pos_ids[:, -length:]

        inputs["image_mask"] = inputs["input_ids"] >= self.config.vocab_size

        return inputs
|
|