anananan116 committed on
Commit
8249d39
·
verified ·
1 Parent(s): 570e2b5

Upload 5 files

Browse files
Files changed (1) hide show
  1. modeling_VLM.py +164 -196
modeling_VLM.py CHANGED
@@ -1,197 +1,165 @@
1
- from .modeling_llama import AdapterMLP, DEFAULT_SYSTEM_PROMPT, LlamaForCausalLM
2
- from .configuration_llama import VLMConfig
3
- from .configuration_clip import CLIPConfig
4
- from .visual_modeling import CLIPModel
5
- import torch
6
- from torch import nn
7
- from transformers import PreTrainedModel, PreTrainedTokenizer, AutoProcessor, GenerationMixin
8
-
9
class VLMPretrainedModel(PreTrainedModel):
    """Shared Hugging Face base class for the VLM family.

    Declares the config class, device-placement hints, and the default
    weight-initialization scheme used by all concrete VLM models.
    """

    config_class = VLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["LlamaDecoderLayer", "Block"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize a single submodule: N(0, initializer_range) weights,
        zero biases, and a zeroed padding row for embeddings."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding exactly zero after init.
                module.weight.data[module.padding_idx].zero_()
27
class AtriVLM(VLMPretrainedModel, GenerationMixin):
    """Vision-language model: a CLIP visual encoder feeding a Llama decoder.

    Image features are projected through an MLP adapter and spliced into the
    decoder's token-embedding sequence at the positions of the image
    placeholder tokens.
    """

    _tied_weights_keys = ['decoder.lm_head.weight']

    def __init__(self, config: VLMConfig):
        super().__init__(config)
        # special_token_map entries are indexed with [1]; presumably each
        # entry is a (token_string, token_id) pair — TODO confirm in config.
        if config.special_token_map:
            self.image_start_token_id = config.special_token_map['Image'][1]
            self.image_end_token_id = config.special_token_map['Image_End'][1]
            self.caption_token_id = config.special_token_map['Caption'][1]
            self.image_token_id = config.special_token_map['Image_Token'][1]
        else:
            raise ValueError("Special token map not found")
        self.image_adapter = AdapterMLP(config)
        self.num_patches = config.num_patches
        self.processor = AutoProcessor.from_pretrained(config.pretrained_vision_model).image_processor
        # String forms of the image special tokens used when formatting prompts.
        self.img_place_holder = "<IMGPLH>"
        self.img_start_token = "<IMAGE>"
        self.img_end_token = "<IMAGE_END>"
        self.image_token = "<Image_Token>"
        # Fix: original wrote LlamaForCausalLM((config)) — redundant parens.
        self.decoder = LlamaForCausalLM(config)
        if config.load_vision_model:
            if isinstance(config.visual_config, dict):
                self.visual = CLIPModel(CLIPConfig(**config.visual_config))
            else:
                self.visual = CLIPModel(config.visual_config)
        else:
            self.visual = None

    def get_input_embeddings(self):
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        return self.decoder.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.decoder.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.decoder.lm_head = new_embeddings

    def forward(self, input_ids=None, encoded_image=None, labels=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """
        Forward pass for the VLM model that combines image and text embeddings.

        Args:
            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)
            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)
            labels (torch.LongTensor): Labels for computing the language modeling loss
        """
        # Only splice image embeddings on the first (prefill) step; on cached
        # decode steps the image tokens are already in past_key_values.
        if not past_key_values and encoded_image is not None:
            encoded_image = encoded_image.to(self.decoder.get_input_embeddings().weight.dtype)
            # Project vision features into the decoder's embedding space.
            processed_image = self.image_adapter(encoded_image)

            token_embeddings = self.decoder.get_input_embeddings()(input_ids)

            # Replace each image-placeholder token embedding with one projected
            # patch embedding. (Removed the original's no-op
            # `token_embeddings = token_embeddings` line.)
            image_token_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)
            token_embeddings[image_token_positions] = processed_image.reshape(-1, processed_image.size(-1))
        else:
            token_embeddings = self.decoder.get_input_embeddings()(input_ids)

        # Delegate to the decoder's native forward with the merged embeddings.
        outputs = self.decoder._native_forward(
            inputs_embeds=token_embeddings,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs

    def prepare_input_ids_for_generation(self, prompts, images, tokenizer, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """
        Prepare input ids and images for generation.

        Args:
            prompts (List[str]): List of text prompts
            images (List[Image]): List of images corresponding to prompts
            tokenizer: Tokenizer instance
            system_prompt (str): System prompt to be prepended

        Returns:
            dict: Contains input_ids, attention_mask, and encoded_image
            (None when no images were supplied).
        """
        # Encode every image through the vision tower.
        processed_images = []
        for image in images:
            pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].to(
                self.visual.vision_model.embeddings.patch_embedding.weight.device
            )
            processed_images.append(self.visual.encode_image(pixel_values))

        # Bug fix: the original fell through to `processed_images.size(0)` on
        # the plain Python list when `images` was empty, raising AttributeError.
        encoded_image = torch.cat(processed_images, dim=0) if processed_images else None

        formatted_prompts = []
        for prompt in prompts:
            # Expand the placeholder into <IMAGE> + num_patches image tokens + <IMAGE_END>
            # so the token count matches the number of patch embeddings.
            if self.img_place_holder in prompt:
                image_token_sequence = (
                    f"{self.img_start_token}" +
                    f"{self.image_token}" * self.num_patches +
                    f"{self.img_end_token}"
                )
                formatted_prompt = prompt.replace(self.img_place_holder, image_token_sequence)
            else:
                formatted_prompt = prompt

            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": formatted_prompt},
            ]
            formatted_prompts.append(
                tokenizer.apply_chat_template(
                    conversation,
                    tokenize=False,
                    add_generation_prompt=True
                )
            )

        tokenized_output = tokenizer(
            formatted_prompts,
            padding=True,
            return_tensors="pt",
            padding_side="left"  # Use left padding since we're generating on the right
        )

        return {
            "input_ids": tokenized_output["input_ids"],
            "attention_mask": tokenized_output["attention_mask"],
            "encoded_image": encoded_image
        }

    def prepare_for_generation(self, input_ids, encoded_image, **kwargs):
        """
        Prepare KV cache for generation by processing the image and initial tokens.

        Args:
            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)
            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)

        Returns:
            past_key_values: Key/value states to be used for subsequent generation
        """
        encoded_image = encoded_image.to(self.decoder.get_input_embeddings().weight.dtype)
        # Project vision features into the decoder's embedding space.
        processed_image = self.image_adapter(encoded_image)

        token_embeddings = self.decoder.get_input_embeddings()(input_ids)

        # Splice the projected patch embeddings over the placeholder tokens.
        image_token_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)
        token_embeddings[image_token_positions] = processed_image.reshape(-1, processed_image.size(-1))

        # Forward pass with caching enabled to prime the KV cache.
        outputs = self.decoder._native_forward(
            inputs_embeds=token_embeddings,
            use_cache=True,
            **kwargs
        )
        return outputs.past_key_values
 
1
+ from .modeling_llama import AdapterMLP, DEFAULT_SYSTEM_PROMPT, LlamaForCausalLM
2
+ from .configuration_llama import VLMConfig
3
+ from .configuration_clip import CLIPConfig
4
+ from .visual_modeling import CLIPModel
5
+ import torch
6
+ from torch import nn
7
+ from transformers import AutoProcessor
8
+
9
class AtriVLM(LlamaForCausalLM):
    """Vision-language model: a CLIP visual encoder on top of a Llama decoder.

    The model subclasses LlamaForCausalLM directly; image features are
    projected through an MLP adapter and spliced into the token-embedding
    sequence at the positions of the image placeholder tokens.
    """

    def __init__(self, config: VLMConfig):
        super().__init__(config)
        # special_token_map entries are indexed with [1]; presumably each
        # entry is a (token_string, token_id) pair — TODO confirm in config.
        if config.special_token_map:
            self.image_start_token_id = config.special_token_map['Image'][1]
            self.image_end_token_id = config.special_token_map['Image_End'][1]
            self.caption_token_id = config.special_token_map['Caption'][1]
            self.image_token_id = config.special_token_map['Image_Token'][1]
        else:
            raise ValueError("Special token map not found")
        self.image_adapter = AdapterMLP(config)
        self.num_patches = config.num_patches
        self.processor = AutoProcessor.from_pretrained(config.pretrained_vision_model).image_processor
        # String forms of the image special tokens used when formatting prompts.
        self.img_place_holder = "<IMGPLH>"
        self.img_start_token = "<IMAGE>"
        self.img_end_token = "<IMAGE_END>"
        self.image_token = "<Image_Token>"
        if config.load_vision_model:
            if isinstance(config.visual_config, dict):
                self.visual = CLIPModel(CLIPConfig(**config.visual_config))
            else:
                self.visual = CLIPModel(config.visual_config)
        else:
            self.visual = None

    def forward(self, input_ids=None, encoded_image=None, labels=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """
        Forward pass for the VLM model that combines image and text embeddings.

        Args:
            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)
            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)
            labels (torch.LongTensor): Labels for computing the language modeling loss
        """
        # Only splice image embeddings on the first (prefill) step; on cached
        # decode steps the image tokens are already in past_key_values.
        if not past_key_values and encoded_image is not None:
            encoded_image = encoded_image.to(self.get_input_embeddings().weight.dtype)
            # Project vision features into the decoder's embedding space.
            processed_image = self.image_adapter(encoded_image)

            token_embeddings = self.get_input_embeddings()(input_ids)

            # Replace each image-placeholder token embedding with one projected
            # patch embedding. (Removed the original's no-op
            # `token_embeddings = token_embeddings` line.)
            image_token_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)
            token_embeddings[image_token_positions] = processed_image.reshape(-1, processed_image.size(-1))
        else:
            token_embeddings = self.get_input_embeddings()(input_ids)

        # Delegate to the native forward with the merged embeddings.
        outputs = self._native_forward(
            inputs_embeds=token_embeddings,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs

    def prepare_input_ids_for_generation(self, prompts, images, tokenizer, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """
        Prepare input ids and images for generation.

        Args:
            prompts (List[str]): List of text prompts
            images (List[Image]): List of images corresponding to prompts
            tokenizer: Tokenizer instance
            system_prompt (str): System prompt to be prepended

        Returns:
            dict: Contains input_ids, attention_mask, and encoded_image
            (None when no images were supplied).
        """
        # Encode every image through the vision tower.
        processed_images = []
        for image in images:
            pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].to(
                self.visual.vision_model.embeddings.patch_embedding.weight.device
            )
            processed_images.append(self.visual.encode_image(pixel_values))

        # Bug fix: the original fell through to `processed_images.size(0)` on
        # the plain Python list when `images` was empty, raising AttributeError.
        encoded_image = torch.cat(processed_images, dim=0) if processed_images else None

        formatted_prompts = []
        for prompt in prompts:
            # Expand the placeholder into <IMAGE> + num_patches image tokens + <IMAGE_END>
            # so the token count matches the number of patch embeddings.
            if self.img_place_holder in prompt:
                image_token_sequence = (
                    f"{self.img_start_token}" +
                    f"{self.image_token}" * self.num_patches +
                    f"{self.img_end_token}"
                )
                formatted_prompt = prompt.replace(self.img_place_holder, image_token_sequence)
            else:
                formatted_prompt = prompt

            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": formatted_prompt},
            ]
            formatted_prompts.append(
                tokenizer.apply_chat_template(
                    conversation,
                    tokenize=False,
                    add_generation_prompt=True
                )
            )

        tokenized_output = tokenizer(
            formatted_prompts,
            padding=True,
            return_tensors="pt",
            padding_side="left"  # Use left padding since we're generating on the right
        )

        return {
            "input_ids": tokenized_output["input_ids"],
            "attention_mask": tokenized_output["attention_mask"],
            "encoded_image": encoded_image
        }

    def prepare_for_generation(self, input_ids, encoded_image, **kwargs):
        """
        Prepare KV cache for generation by processing the image and initial tokens.

        Args:
            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)
            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)

        Returns:
            past_key_values: Key/value states to be used for subsequent generation
        """
        encoded_image = encoded_image.to(self.get_input_embeddings().weight.dtype)
        # Project vision features into the decoder's embedding space.
        processed_image = self.image_adapter(encoded_image)

        token_embeddings = self.get_input_embeddings()(input_ids)

        # Splice the projected patch embeddings over the placeholder tokens.
        image_token_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)
        token_embeddings[image_token_positions] = processed_image.reshape(-1, processed_image.size(-1))

        # Forward pass with caching enabled to prime the KV cache.
        outputs = self._native_forward(
            inputs_embeds=token_embeddings,
            use_cache=True,
            **kwargs
        )
        return outputs.past_key_values