jiang-cc committed on
Commit
d9dc657
·
verified ·
1 Parent(s): cc36dfc

fix: sync all compat fixes from AD-Copilot (imports, get_rope_index kwargs, remove sdpa hardcode)

Browse files
Files changed (1) hide show
  1. modeling_ad_copilot.py +758 -0
modeling_ad_copilot.py ADDED
@@ -0,0 +1,758 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from typing import Any, Callable, Optional, Union
6
+
7
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText
8
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
9
+ Qwen2_5_VisionTransformerPretrainedModel,
10
+ Qwen2_5_VLModel,
11
+ Qwen2_5_VLMLP,
12
+ ALL_ATTENTION_FUNCTIONS,
13
+ )
14
+ try:
15
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm
16
+ except ImportError:
17
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
18
+ try:
19
+ from transformers.image_utils import ImageInput
20
+ except ImportError:
21
+ ImageInput = Any
22
+ try:
23
+ from transformers.tokenization_utils import TextInput, PreTokenizedInput
24
+ except ImportError:
25
+ TextInput = str
26
+ PreTokenizedInput = list
27
+ try:
28
+ from transformers.video_utils import VideoInput
29
+ except ImportError:
30
+ VideoInput = Any
31
+ try:
32
+ from transformers.feature_extraction_utils import BatchFeature
33
+ except ImportError:
34
+ from transformers import BatchFeature
35
+
36
+ from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLConfig
37
+ try:
38
+ from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs
39
+ except ImportError:
40
+ Qwen2_5_VLProcessorKwargs = dict
41
+
42
class ADCopilotConfig(Qwen2_5_VLConfig):
    """Configuration for AD-Copilot, a Qwen2.5-VL variant with per-image compare tokens."""

    model_type = "ad_copilot"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Number of learnable "compare" tokens appended after each image's
        # visual tokens; read by the vision tower and the processor.
        self.vision_config.compare_token_size = 100
        # When True, each image is compared against the previous one in the
        # sequence (the first image is compared against itself).
        self.sequence_compare = True
        self.architectures = ["ADCopilotVLForConditionalGeneration"]
49
+
50
class ADCopilotProcessor(Qwen2_5_VLProcessor):
    # Qwen2.5-VL processor variant: for every image placeholder it reserves
    # `compare_token_size` EXTRA image tokens, making room in the token stream
    # for the compare-visual embeddings produced by the model's vision tower.
    config_class = ADCopilotConfig

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)
        # Extra per-image tokens; must match the model's
        # vision_config.compare_token_size (default 100).
        self.compare_token_size = 100 if "compare_token_size" not in kwargs else kwargs["compare_token_size"]

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Unlike the stock Qwen2.5-VL processor, each image placeholder is expanded
        by an additional `self.compare_token_size` image tokens (see the image
        branch below) to reserve positions for compare-visual embeddings.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        image_inputs = videos_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]

            # fps may be a scalar (applied to every video) or a per-video sequence.
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # below lines change text in-place
        if images is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    num_image_tokens = image_grid_thw[index].prod() // merge_length
                    # text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens), 1)
                    # Reserve `compare_token_size` extra slots per image for the
                    # compare-visual embeddings emitted by the vision tower.
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * (num_image_tokens + self.compare_token_size), 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if videos is not None:
            # Videos are NOT given compare tokens — only images are.
            merge_length = self.video_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    num_video_tokens = video_grid_thw[index].prod() // merge_length
                    text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            # Mark image-token positions with type id 1, everything else 0.
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
163
+
164
+
165
class OptimizedCrossAttention(nn.Module):
    """
    Optimized cross-attention module modeled after the Qwen2_5_VLVisionAttention structure.

    Supports both cross attention (separate Q vs. fused K/V projections) and
    self attention (single fused QKV projection), and dispatches to whichever
    attention kernel the config selects (sdpa / eager / flash_attention_2).
    """
    def __init__(self, config, is_cross_attention=True):
        super().__init__()
        self.config = config
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_causal = False  # cross attention needs no causal mask
        self.is_cross_attention = is_cross_attention

        if is_cross_attention:
            # Cross attention: Q comes from one sequence, K/V from another.
            self.q_proj = nn.Linear(self.dim, self.dim, bias=True)
            self.kv = nn.Linear(self.dim, self.dim * 2, bias=True)  # fused K/V
        else:
            # Self attention: Q, K, V all come from the same sequence.
            self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)  # fused Q/K/V

        self.proj = nn.Linear(self.dim, self.dim, bias=True)

    def forward(
        self,
        query_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,  # used by FA2 only
        kv_cu_seqlens: Optional[torch.Tensor] = None,  # used by FA2 only
        **kwargs,
    ) -> torch.Tensor:
        """Run (cross-)attention; returns a tensor with the same leading shape as `query_states`."""
        # Accept query_states as [B, T, d] or [T, d]; auto-expand the batch dim.
        orig_2d = False
        if query_states.dim() == 2:
            query_states = query_states.unsqueeze(0)
            orig_2d = True

        batch_size, seq_len_q, _ = query_states.shape

        # Q/K/V projections
        if self.is_cross_attention and key_value_states is not None:
            if key_value_states.dim() == 2:
                key_value_states = key_value_states.unsqueeze(0)
            q = self.q_proj(query_states)
            kv = self.kv(key_value_states)
            seq_len_kv = kv.shape[1]
            # kv -> [2, B, nH, T_kv, d]; unbind splits into K and V.
            k, v = kv.reshape(batch_size, seq_len_kv, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
            q = q.reshape(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            if key_value_states is None:
                key_value_states = query_states
            qkv = self.qkv(query_states)
            q, k, v = qkv.reshape(batch_size, seq_len_q, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        # Pick the attention kernel configured for the model.
        attn_impl = getattr(self.config, '_attn_implementation', 'sdpa')
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[attn_impl]

        # ========= FlashAttention-2 support ==========
        if attn_impl == "flash_attention_2":
            # Qwen2.5's FA2 path works because inputs are flattened with
            # cu_seqlens; here query/key sequences may be variable-length per batch.

            # Use the provided cu_seqlens if given, otherwise generate them
            # assuming every batch element has the full (padded) length.
            if cu_seqlens is None:
                cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device)
            if kv_cu_seqlens is None:
                cu_seqlens_k = torch.arange(0, (batch_size + 1) * k.shape[2], step=k.shape[2], dtype=torch.int32, device=k.device)
            else:
                cu_seqlens_k = kv_cu_seqlens

            # flatten [B, nH, T, d] -> [total_T, nH, d]
            # NOTE: FlashAttn2 expects (total, nH, d), not (nH, total, d) —
            # different from the dense implementations.
            # Safer flatten: [B, nH, T, d] -> [B, T, nH, d] -> [total_T, nH, d]
            q_ = q.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
            k_ = k.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)
            v_ = v.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim)

            max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()

            attn_output, _ = attention_interface(
                self,
                q_,
                k_,
                v_,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens_k,
                max_length_q=max_seqlen_q,
                max_length_k=max_seqlen_k,
                is_causal=self.is_causal,
                **kwargs,
            )

            # Restore the batched layout:
            # [total_q, nH, d] -> [B, seq_len_q, nH, d]
            attn_output = attn_output.view(batch_size, seq_len_q, self.num_heads, self.head_dim).contiguous()
        else:
            # Dense implementations (sdpa/eager) take [B, nH, T, d] directly.
            attn_output, _ = attention_interface(
                self,
                q, k, v,
                attention_mask=attention_mask,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                is_causal=self.is_causal,
                **kwargs,
            )
            # attn_output: [B, nH, seq_q, d]
            attn_output = attn_output.transpose(1, 2).contiguous()  # [B, seq_q, nH, d]

        attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim)  # [B, seq_q, D]
        attn_output = self.proj(attn_output)
        # If the caller passed a 2-D query, drop the batch dim we added.
        if orig_2d:
            attn_output = attn_output.squeeze(0)
        return attn_output.contiguous()
289
+
290
+
291
class ADCopilotCompareVisualEncoder(nn.Module):
    """Distills pairwise (current vs. previous) image differences into a fixed
    number of learnable "compare" tokens.

    Encoder stage: two cross-attention passes let consecutive images interact
    (previous attends to current, then current attends to the enhanced
    previous).  Decoder stage: `token_size` learnable query embeddings attend
    to the encoded difference features and are projected to the LLM hidden
    size via `compare_projector`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Compare against the previous image in the sequence (first image is
        # compared against itself); when False, each image's rolled neighbor
        # is used as-is.
        self.sequence_compare = getattr(config, "sequence_compare", True)
        self.hidden_size = config.hidden_size
        # Defensive attribute read: `"compare_token_size" not in config` is not
        # a supported membership test on PretrainedConfig objects, so use
        # getattr with the same default (100), mirroring `sequence_compare`.
        self.token_size = getattr(config, "compare_token_size", 100)

        # Encoder: bidirectional feature interaction between image pairs.
        # First cross attention: previous attends to current.
        self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
        # Second cross attention: current attends to previous.
        self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)

        self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm4 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_mlp1 = Qwen2_5_VLMLP(config)
        self.encoder_mlp2 = Qwen2_5_VLMLP(config)

        # Decoder: learnable query embeddings attend to the encoded features.
        self.query_embeddings = nn.Parameter(
            torch.empty(self.token_size, self.hidden_size)
        )
        # Only a cross attention is kept for queries to attend to encoded features.
        self.decoder_cross_attn = OptimizedCrossAttention(config, is_cross_attention=True)

        self.decoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_mlp = Qwen2_5_VLMLP(config)

        self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)

    def init_query_embeddings(self):
        """Initialize the learnable query embeddings (N(0, 0.02))."""
        nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)

    def forward(self, images_hidden_states: list) -> torch.Tensor:
        """
        Args:
            images_hidden_states: list of tensors, each of shape
                [seq_len, hidden_size] (seq_len may differ per image).

        Returns:
            Tensor of shape [total_images, token_size, out_hidden_size]
            (empty [0, token_size, hidden_size] when the input list is empty).
        """
        if not images_hidden_states:
            # Match the module's parameter device/dtype so an empty result is
            # still usable by downstream tensor ops.
            return torch.empty(
                0,
                self.token_size,
                self.hidden_size,
                device=self.query_embeddings.device,
                dtype=self.query_embeddings.dtype,
            )

        # Diagnostic only: flag NaNs in the learnable queries.
        if torch.isnan(self.query_embeddings).any():
            print("警告:query_embeddings 包含 NaN 值")

        # Pad every image's token sequence to the max length and build masks.
        seq_lengths = [state.size(0) for state in images_hidden_states]
        max_seq_len = max(seq_lengths)
        batch_size = len(images_hidden_states)
        device = images_hidden_states[0].device

        padded_states = []
        attention_masks = []
        for state in images_hidden_states:
            pad_len = max_seq_len - state.size(0)
            attention_mask = torch.ones(max_seq_len, dtype=torch.bool, device=device)
            if pad_len > 0:
                # Zero-pad along the sequence dimension; mask out the padding.
                state = F.pad(state, (0, 0, 0, pad_len), mode='constant', value=0)
                attention_mask[state.size(0) - pad_len:] = False
            padded_states.append(state)
            attention_masks.append(attention_mask)

        # [batch_size, max_seq_len, hidden_size]
        batched_states = torch.stack(padded_states)
        # [batch_size, max_seq_len]
        attention_masks = torch.stack(attention_masks)

        # Build the "previous image" batch by rolling one step along the batch.
        previous_states = torch.roll(batched_states, shifts=1, dims=0)
        previous_masks = torch.roll(attention_masks, shifts=1, dims=0)

        # After the roll, index 0 holds the LAST image; fix it up so the first
        # image is compared against itself (rolled index 1 == original image 0).
        if previous_states.size(0) > 1 and self.sequence_compare:
            previous_states[0] = previous_states[1]
            previous_masks[0] = previous_masks[1]

        # Encoder: process all image pairs in one batch.
        encoded_features = self._encoder_forward(
            batched_states,   # [batch_size, max_seq_len, hidden_size]
            previous_states,  # [batch_size, max_seq_len, hidden_size]
            attention_masks,  # [batch_size, max_seq_len]
            previous_masks,   # [batch_size, max_seq_len]
        )

        # Decoder: expand the shared queries to the batch and attend.
        batch_queries = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        # [batch_size, token_size, hidden_size]
        compare_visual_embeds = self._decoder_forward(
            batch_queries,
            encoded_features,
            torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device),  # query mask (all valid)
            attention_masks,  # mask over the encoded features
        )

        # Project every compare token to the LLM hidden size:
        # flatten -> linear -> restore [batch_size, token_size, out_hidden_size].
        token_size = compare_visual_embeds.size(1)
        flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
        merged = self.compare_projector(flattened_embeds)
        compare_visual_embeds = merged.view(batch_size, token_size, -1)

        return compare_visual_embeds  # [batch_size, token_size, out_hidden_size]

    def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
        """
        Encoder: bidirectional interaction between current and previous features.

        Args:
            current_features: [batch_size, seq_len, hidden_size]
            previous_features: [batch_size, seq_len, hidden_size]
            current_mask: [batch_size, seq_len] (True = attend)
            previous_mask: [batch_size, seq_len] (True = attend)
        """
        # Step 1: previous attends to current (pre-norm, shared encoder_norm1).
        residual = previous_features
        previous_normed = self.encoder_norm1(previous_features)
        current_normed1 = self.encoder_norm1(current_features)
        cross_attn_output1 = self.encoder_cross_attn1(
            query_states=previous_normed,
            key_value_states=current_normed1,
            attention_mask=current_mask.unsqueeze(1).unsqueeze(2) if current_mask is not None else None,
        )
        previous_features = residual + cross_attn_output1

        # MLP block on the enhanced previous features.
        residual = previous_features
        previous_features = residual + self.encoder_mlp1(self.encoder_norm2(previous_features))

        # Step 2: current attends to the enhanced previous features
        # (pre-norm, shared encoder_norm3).
        residual = current_features
        current_normed2 = self.encoder_norm3(current_features)
        previous_normed2 = self.encoder_norm3(previous_features)
        cross_attn_output2 = self.encoder_cross_attn2(
            query_states=current_normed2,
            key_value_states=previous_normed2,
            attention_mask=previous_mask.unsqueeze(1).unsqueeze(2) if previous_mask is not None else None,
        )
        current_features = residual + cross_attn_output2

        # Deliberate SUBTRACTION instead of the usual residual addition —
        # the MLP branch is subtracted to emphasize inter-frame differences.
        residual = current_features
        current_features = residual - self.encoder_mlp2(self.encoder_norm4(current_features))
        return current_features

    def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
        """
        Decoder: queries interact with the encoded difference features.

        Args:
            queries: [batch_size, token_size, hidden_size]
            encoded_features: [batch_size, seq_len, hidden_size]
            query_mask: [batch_size, token_size] (currently unused by the attention call)
            encoded_mask: [batch_size, seq_len] (True = attend)
        """
        # Cross attention: queries attend to encoded features (pre-norm,
        # shared decoder_norm1 for both streams).
        residual = queries
        queries_normed = self.decoder_norm1(queries)
        encoded_normed = self.decoder_norm1(encoded_features)
        cross_attn_output = self.decoder_cross_attn(
            query_states=queries_normed,
            key_value_states=encoded_normed,
            attention_mask=encoded_mask.unsqueeze(1).unsqueeze(2) if encoded_mask is not None else None,
        )
        queries = residual + cross_attn_output

        # Standard MLP block with residual.
        residual = queries
        queries = residual + self.decoder_mlp(self.decoder_norm2(queries))

        return queries  # [batch_size, token_size, hidden_size]
502
+
503
+
504
+ # 先把组件继承出来方便修改
505
# Subclassed components to make the pipeline easy to modify.
class ADCopilotVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):
    # Qwen2.5-VL vision tower plus a compare-visual encoder that summarizes
    # per-image differences into a fixed number of compare tokens.
    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.compare_visual_encoder = ADCopilotCompareVisualEncoder(config)

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: hidden_states, compare_visual_embeds.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        # Window partition for windowed attention blocks; window_index permutes
        # the flattened token sequence into window order.
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (and their rotary embeddings) into window order,
        # grouped by spatial merge units.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            # Full-attention blocks use image-level cu_seqlens; all other
            # blocks use window-level cu_seqlens.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Split the (pre-merger) token stream back into per-image chunks and
        # feed them to the compare encoder.
        # NOTE(review): at this point hidden_states is still in window-permuted
        # order (reverse_indices is applied only after the merger below), so
        # the per-image split here assumes the window permutation does not mix
        # tokens across images — confirm against get_window_index.
        split_sizes = grid_thw.prod(-1).tolist()
        splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
        # [total_images, token_size, hidden_size]
        compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)


        hidden_states = self.merger(hidden_states)
        # Undo the window permutation to restore the original token order.
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states, compare_visual_embeds
575
+
576
class ADCopilotVLModel(Qwen2_5_VLModel):
    # Qwen2.5-VL backbone with the AD-Copilot vision tower and a RoPE index
    # computation that accounts for the extra per-image compare tokens.
    def __init__(self, config):
        super().__init__(config)
        self.visual = ADCopilotVisionTransformerPretrainedModel._from_config(config.vision_config)
        # Number of compare tokens appended after each image's visual tokens.
        self.compare_token_size = config.vision_config.compare_token_size
        # self.learnable_image_embeddings = nn.Parameter(
        #     torch.randn(100, config.hidden_size) * 0.02  # small init values
        # )

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model.

        Returns a LIST of per-image tensors; each entry is the image's merged
        visual tokens with its compare-visual tokens concatenated at the end,
        so each entry has (grid.prod() // merge**2 + compare_token_size) rows.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, compare_visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        # Split the merged visual tokens back into per-image chunks; each image
        # then gets its compare-awareness tokens appended.
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)

        # Concatenate each image's embeddings with its compare-visual embeddings.
        enhanced_image_embeds = []
        for i, embeds in enumerate(image_embeds):
            # Ensure compare_visual_embeds[i] matches embeds' device and dtype.
            compare_embed = compare_visual_embeds[i].to(device=embeds.device, dtype=embeds.dtype)
            enhanced_embeds = torch.cat([embeds, compare_embed], dim=0)
            enhanced_image_embeds.append(enhanced_embeds)

        # image_embeds = torch.cat(enhanced_image_embeds, dim=0)
        return enhanced_image_embeds

    def get_rope_index(self, input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
        # **kwargs absorbs extra arguments newer transformers versions pass;
        # delegate to the compare-token-aware implementation.
        return self.get_rope_index_with_compare_token(input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask)

    def get_rope_index_with_compare_token(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute 3D (temporal/height/width) M-RoPE position ids, mirroring the
        upstream Qwen2.5-VL `get_rope_index` but additionally inserting
        `compare_token_size` positions after each IMAGE's vision tokens (videos
        get no compare tokens).

        Returns:
            (position_ids [3, batch, seq], mrope_position_deltas [batch, 1]).
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                # Work only on the unpadded tokens of this row.
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                # Each vision segment starts with vision_start_token_id; the
                # token right after it tells whether it's an image or a video.
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for vision_index in range(image_nums + video_nums):
                    # Find the next image/video token at or after st; a missing
                    # modality gets a sentinel past the end of the sequence.
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        # Next vision segment is an image.
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        # Images carry no temporal spacing (static content).
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image

                    else:
                        # Next vision segment is a video.
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    # Text tokens before this vision segment: all three axes
                    # advance together.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                    ## normalize type, send to device.
                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                    )

                    # Temporal position advances by real time (seconds * tokens/s).
                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second

                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
                    if ed_image < ed_video:
                        # For an image, also insert positions for the
                        # compare_token_size image-comparison tokens; all three
                        # axes reuse the last temporal index so the compare
                        # tokens share a single position step.
                        compare_t_index = t_index[-1].repeat(self.compare_token_size)
                        # compare_h_index = torch.arange(self.compare_token_size)
                        # compare_w_index = torch.arange(self.compare_token_size)
                        compare_h_index = compare_t_index
                        compare_w_index = compare_t_index
                        llm_pos_ids_list.append(torch.stack([compare_t_index, compare_h_index, compare_w_index]) + text_len + st_idx)
                        st = st + self.compare_token_size

                if st < len(input_tokens):
                    # Trailing text after the last vision segment.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            # NOTE: `input_ids` here is the loop variable (last processed row),
            # matching upstream; only its .device is used.
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            # Text-only path: plain sequential positions replicated on all 3 axes.
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas
752
+
753
class ADCopilotVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL conditional-generation head backed by the AD-Copilot backbone."""

    config_class = ADCopilotConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap the stock Qwen2_5_VLModel created by the parent constructor for
        # the AD-Copilot variant (compare-token vision tower + custom M-RoPE).
        self.model = ADCopilotVLModel(config)