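"""Multi-modality causal LM: a LLaMA language model wired to a vision encoder
(plus aligner) for image understanding, and a VQ image decoder with a
generation head for autoregressive image generation."""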
import numpy as np
import torch
from einops import rearrange
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
    GenerationMixin,
)
from transformers.configuration_utils import PretrainedConfig

from .clip_encoder import CLIPVisionTower
from .siglip_vit import create_siglip_vit
from .projector import MlpProjector
from .configuration_vlm import (
    AttrDict,
    MultiModalityConfig,
    VisionConfig,
    AlignerConfig,
    GenVisionConfig,
    GenHeadConfig,
    GenAlignerConfig,
)
from .vq_model import VQ_models


class vision_head(torch.nn.Module):
    """Generation head: maps LM hidden states to logits over the image-token vocabulary."""

    def __init__(self, params):
        super().__init__()
        self.output_mlp_projector = torch.nn.Linear(
            params.n_embed, params.image_token_embed
        )
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(
            params.image_token_embed, params.image_token_size
        )

    def forward(self, x):
        x = self.output_mlp_projector(x)
        x = self.vision_activation(x)
        x = self.vision_head(x)
        return x


def model_name_to_cls(cls_name):
    """Resolve a class name string from the config into the corresponding module class."""
    if "MlpProjector" in cls_name:
        cls = MlpProjector
    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower
    elif "VQ" in cls_name:
        # VQ_models is imported at module level; the name selects the variant.
        cls = VQ_models[cls_name]
    elif "vision_head" in cls_name:
        cls = vision_head
    else:
        raise ValueError(f"class_name {cls_name} is invalid.")
    return cls


class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        # Vision encoder for image understanding.
        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        # Aligner projects vision features into the LM embedding space.
        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        # VQ model that decodes generated image tokens back to pixels.
        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        # Aligner for the generation path (image-token embeddings -> LM space).
        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        # Head that predicts the next image token from LM hidden states.
        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        # Embedding table for generated image-token ids.
        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_emb_mask: torch.LongTensor,
        **kwargs,
    ):
        """
        Args:
            input_ids (torch.LongTensor): [b, T]
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """
        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # Negative ids mark image-token placeholders; map them to a valid id
        # before the embedding lookup (they are overwritten just below).
        input_ids[input_ids < 0] = 0
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Scatter the projected image features into the placeholder positions.
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        # Embed VQ image-token ids and project them into the LM embedding space.
        return self.gen_aligner(self.gen_embed(image_ids))

    def forward(
        self,
        input_ids,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        return self.language_model.forward(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs
            )
        return self.language_model.generate(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

    @torch.no_grad()
    def generate_image(
        self,
        processor,
        prompt: str,
        temperature: float = 1,
        parallel_size: int = 16,
        cfg_weight: float = 5,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
        generator=None,
    ):
        from PIL import Image

        conversation = [
            {
                "role": "User",
                "content": prompt,
            },
            {"role": "Assistant", "content": ""},
        ]

        sft_format = processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=processor.sft_format,
            system_prompt="",
        )
        prompt = sft_format + processor.image_start_tag
        input_ids = processor.tokenizer.encode(prompt)
        input_ids = torch.LongTensor(input_ids)

        # Interleave conditional (even rows) and unconditional (odd rows) prompts
        # for classifier-free guidance; unconditional rows keep only the first and
        # last token. Tensors are created on the model's device so this also works
        # when the model lives on a GPU.
        tokens = torch.zeros(
            (parallel_size * 2, len(input_ids)), dtype=torch.int, device=self.device
        )
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = processor.pad_id

        inputs_embeds = self.language_model.get_input_embeddings()(tokens)

        generated_tokens = torch.zeros(
            (parallel_size, image_token_num_per_image),
            dtype=torch.int,
            device=self.device,
        )
        past_key_values = None

        for i in range(image_token_num_per_image):
            outputs = self.language_model.model.forward(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values,
            )
            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values
            logits = self.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]

            # Classifier-free guidance: push the logits away from the
            # unconditional distribution by cfg_weight.
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)

            next_token = (
                torch.multinomial(probs, num_samples=1)
                if generator is None
                else torch.multinomial(probs, num_samples=1, generator=generator)
            )
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Duplicate each sampled token for the cond/uncond streams and feed
            # its embedding back as the next step's single-token input.
            next_token = torch.cat(
                [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
            ).view(-1)
            img_embeds = self.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        # Decode the sampled VQ codes back into pixel space.
        dec = self.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            [parallel_size, 8, img_size // patch_size, img_size // patch_size],
        )
        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

        # Map from [-1, 1] to [0, 255].
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec

        images = []
        for i in range(parallel_size):
            images.append(Image.fromarray(visual_img[i]))

        return images


AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
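
# Usage sketch (not part of the model definition): how the registrations above
# are meant to be consumed. The checkpoint path is a placeholder and the
# `processor` object is assumed to be the multimodal processor that ships with
# the checkpoint (it must provide apply_sft_template_for_multi_turn_prompts,
# sft_format, image_start_tag, pad_id, and a tokenizer, as used above).
#
#     import torch
#     from transformers import AutoModelForCausalLM
#
#     model_path = "path/to/checkpoint"  # hypothetical
#     model = AutoModelForCausalLM.from_pretrained(
#         model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
#     ).cuda().eval()
#
#     # Text-to-image: returns a list of `parallel_size` PIL images.
#     images = model.generate_image(processor, "a photo of a red panda")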