shilinxu committed on
Commit
5cb7f5a
·
verified ·
1 Parent(s): e3116d0

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MoonVitPretrainedModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_moonvit.MoonViTConfig",
7
+ "AutoModel": "modeling_moonvit.MoonVitPretrainedModel"
8
+ },
9
+ "hidden_size": 1152,
10
+ "text_hidden_size": 2048,
11
+ "init_pos_emb_height": 64,
12
+ "init_pos_emb_width": 64,
13
+ "intermediate_size": 4304,
14
+ "merge_kernel_size": [
15
+ 2,
16
+ 2
17
+ ],
18
+ "model_type": "moonvit",
19
+ "num_attention_heads": 16,
20
+ "num_hidden_layers": 27,
21
+ "patch_size": 14,
22
+ "torch_dtype": "bfloat16",
23
+ "transformers_version": "4.52.1"
24
+ }
configuration_moonvit.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
class MoonViTConfig(PretrainedConfig):
    """Configuration for the MoonViT vision tower.

    Groups the patch-embedding, transformer-encoder and patch-merger
    hyper-parameters consumed by `MoonVitPretrainedModel`.
    """

    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        text_hidden_size: int = 2048,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Patch embedding.
        self.patch_size = patch_size
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer encoder.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.text_hidden_size = text_hidden_size
        # Patch merger.
        self.merge_kernel_size = merge_kernel_size
32
+
image_processing_moonvit.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import numpy as np
4
+ from PIL import Image
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+ from torchvision.transforms import functional as TF
9
+ from transformers.image_utils import ImageInput, make_list_of_images, valid_images
10
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
11
+ from transformers.utils import TensorType
12
+
13
+
14
# Channel-wise RGB normalization statistics of the original OpenAI CLIP
# training set, applied to inputs in `MoonViTImageProcessor.normalize`.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
16
+
17
+
18
class MoonViTImageProcessor(BaseImageProcessor):
    """Image processor for MoonViT.

    Resizes images under a token budget, makes their dimensions compatible
    with the patch size (pad or center-crop), normalizes with CLIP
    statistics, and splits each image into a flattened patch sequence.
    """

    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        pad_input: bool = False,
        image_mean: tuple[float, float, float] = OPENAI_DATASET_MEAN,
        image_std: tuple[float, float, float] = OPENAI_DATASET_STD,
        in_token_limit: int = 4096,
        # Immutable tuple default (the old `[2, 2]` was a shared mutable
        # default argument); annotation fixed from invalid `list[int, int]`.
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_token_limit = in_token_limit
        self.patch_size = patch_size
        self.pad_input = pad_input
        self.image_mean = image_mean
        self.image_std = image_std
        self.merge_kernel_size = merge_kernel_size

    def rescale(
        self, image: Image.Image, merge_kernel_size: tuple[int, int] = (2, 2)
    ) -> Image.Image:
        """Resize `image` under the token budget and make it patch-aligned.

        If the patch grid would exceed `in_token_limit`, the image is
        bicubically down-scaled. Then, depending on `self.pad_input`, it is
        either padded (right/bottom) up to a multiple of
        `merge_kernel_size * patch_size` or center-cropped down to a
        multiple of `patch_size`.

        Raises:
            ValueError: if either resulting grid dimension reaches 512,
                the positional-embedding table limit used by the model.
        """
        w, h = image.size
        patch_size = self.patch_size

        # Downscale so the patch count respects the input token budget.
        if (w // patch_size) * (h // patch_size) > self.in_token_limit:
            scale = math.sqrt(
                self.in_token_limit / ((w // patch_size) * (h // patch_size))
            )
            new_w, new_h = int(w * scale), int(h * scale)
            image = image.resize((new_w, new_h), Image.Resampling.BICUBIC)

        if self.pad_input:
            # Pad right/bottom so H and W are multiples of (kernel * patch).
            new_w, new_h = image.size
            pad_size_h = merge_kernel_size[0] * patch_size
            pad_size_w = merge_kernel_size[1] * patch_size

            pad_h = (pad_size_h - new_h % pad_size_h) % pad_size_h
            pad_w = (pad_size_w - new_w % pad_size_w) % pad_size_w

            image = TF.pad(image, (0, 0, pad_w, pad_h))
        else:
            # Center-crop so H and W are multiples of the patch size.
            new_w, new_h = image.size
            new_w = new_w - new_w % patch_size
            new_h = new_h - new_h % patch_size
            image = TF.center_crop(image, (new_h, new_w))

        w, h = image.size
        if w // patch_size >= 512 or h // patch_size >= 512:
            raise ValueError("Exceed pos emb")

        return image

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a float CHW tensor in [0, 1]."""
        return TF.to_tensor(image.convert("RGB"))

    def normalize(self, image: torch.Tensor) -> torch.Tensor:
        """Apply per-channel mean/std normalization."""
        return TF.normalize(image, self.image_mean, self.image_std)

    def patchify(self, image: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Split a CHW tensor into non-overlapping patches.

        Returns:
            patches: (H//p * W//p, C, p, p), row-major over the patch grid.
            grid_hw: (H//p, W//p).
        """
        patch_size = self.patch_size
        C, H, W = image.shape
        patches = image.reshape(C, H // patch_size, patch_size, W // patch_size, patch_size)
        patches = patches.permute(1, 3, 0, 2, 4)
        patches = patches.contiguous().view(-1, C, patch_size, patch_size)
        grid_hw = (H // patch_size, W // patch_size)
        return patches, grid_hw

    def _preprocess(self, image: ImageInput) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Preprocess image and patchify it.

        Args:
            image (`ImageInput`):
                Image to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.

        Returns:
            patches: torch.Tensor
            grid_hw: tuple[int, int]
        """
        image = self.rescale(image, self.merge_kernel_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        patches, grid_hw = self.patchify(image)
        return patches, grid_hw

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Preprocess one or many images into packed patches + grid sizes."""
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        pixel_values, image_grid_hws = [], []
        for image in images:
            patches, image_grid_hw = self._preprocess(image)
            pixel_values.append(patches)
            image_grid_hws.append(image_grid_hw)
        # All patch sequences are packed along dim 0; grids let consumers
        # recover per-image boundaries.
        pixel_values = torch.concat(pixel_values, dim=0)
        image_grid_hws = np.array(image_grid_hws)
        data = {"pixel_values": pixel_values, "image_grid_hws": image_grid_hws}

        return BatchFeature(data=data, tensor_type=return_tensors)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23e7427a23e6dc97a03f969e893ef57ac843b2df54ba9f6630ee333cf351744e
3
+ size 895125904
modeling_moonvit.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ from copy import deepcopy
4
+ from typing import Union, Tuple, Sequence, Optional, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from transformers.activations import GELUActivation, ACT2FN, PytorchGELUTanh
10
+ from transformers.modeling_utils import PreTrainedModel
11
+ from transformers.utils import is_flash_attn_2_available
12
+
13
+ from .configuration_moonvit import MoonViTConfig
14
+
15
# Prefer the fused varlen FlashAttention-2 kernel when the `flash_attn`
# package is installed; otherwise the name is set to None and callers fall
# back to the sdpa/eager implementations defined below.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None
19
+
20
+
21
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
):
    """Packed multi-head attention backed by FlashAttention-2.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative query lengths; first element 0, last
            element q.shape[0].
        k_cu_seqlens: cumulative key lengths; first element 0, last
            element k.shape[0] (== v.shape[0]).

    Returns:
        Packed output of shape (tot_seqlens, num_heads * head_dim).
    """
    # Sanity checks; the flash kernel only supports half-precision inputs.
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
    assert (
        k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
    ), "k_cu_seqlens must sum to k.shape[0]"
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"

    # Longest individual segment on either side, required by the kernel.
    max_seqlen_q = torch.diff(q_cu_seqlens).max().item()
    max_seqlen_k = torch.diff(k_cu_seqlens).max().item()

    packed_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
    )
    # Merge head and head_dim axes: (tot_seqlens, num_heads * head_dim).
    return packed_out.flatten(start_dim=-2)
66
+
67
+
68
def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Packed attention via `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative segment lengths; tokens may only attend
            within their own segment (block-diagonal boolean mask).
        k_cu_seqlens: kept for signature parity with the other backends.
    """
    total_len = q.shape[0]
    # True = "may attend": one diagonal block per packed segment.
    allow = torch.zeros(
        [1, total_len, total_len], device=q.device, dtype=torch.bool
    )
    for start, stop in zip(q_cu_seqlens[:-1], q_cu_seqlens[1:]):
        allow[..., start:stop, start:stop] = True

    # SDPA expects (num_heads, seq, head_dim).
    qh = q.transpose(0, 1)
    kh = k.transpose(0, 1)
    vh = v.transpose(0, 1)
    out = F.scaled_dot_product_attention(qh, kh, vh, allow, dropout_p=0.0)
    # Back to packed layout, heads merged into the feature dimension.
    return out.transpose(0, 1).reshape(total_len, -1)
97
+
98
+
99
def eager_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Pure-PyTorch block-diagonal attention over packed sequences.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative segment lengths; tokens may only attend
            within their own [q_cu_seqlens[i-1], q_cu_seqlens[i]) segment.
        k_cu_seqlens: kept for signature parity with the other backends.

    Returns:
        Packed output of shape (tot_seqlens, num_heads * head_dim).
    """
    seq_length = q.shape[0]
    # Boolean mask: True where attention is allowed (block-diagonal).
    attention_mask = torch.zeros(
        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
    )
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
        ] = True
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)

    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # Fix: the boolean mask must suppress disallowed positions with -inf.
    # The previous code *added* the bool mask to the logits, which merely
    # shifted in-segment logits by +1 (a no-op under softmax) and left
    # cross-segment attention completely unmasked.
    attn_weight = attn_weight.masked_fill(~attention_mask, float("-inf"))
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)

    attn_output = attn_weight @ v
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output
128
+
129
+
130
# Dispatch table mapping the HF `attn_implementation` string to the packed
# attention backend defined above.
VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "sdpa": sdpa_attention,
    "eager": eager_attention,
}
135
+
136
+
137
def _apply_rope_input_validation(x, freqs_cis):
    # Leading dims must agree; freqs_cis holds one complex rotation per
    # *pair* of channels, hence the factor-of-2 check on the last dim.
    assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
    assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype


def apply_rope(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to query and key tensors.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64.
            Precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # View adjacent channel pairs as complex numbers: ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    # Fix: reshape xk with *its own* shape. The original used xq.shape here,
    # which silently mis-reshaped (or crashed) whenever q and k did not have
    # identical head counts.
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
165
+
166
+
167
class Learnable2DInterpPosEmb(nn.Module):
    """Learnable 2D positional embedding, bicubically resized per image grid.

    Holds a (height, width, dim) parameter table; for each image grid in the
    batch the table is interpolated to that grid's (h, w) and added to the
    packed tokens.
    """

    def __init__(
        self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
    ) -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to packed tokens.

        Args:
            x: packed tokens of shape (sum(h*w), dim).
            grid_hws: per-image grid sizes, shape (num_images, 2).
        """
        pos_embs = []
        for shape in grid_hws.tolist():
            # Fix: `shape` is a Python list while `self.weight.shape[:-1]` is
            # a torch.Size (a tuple), so the original `shape == ...` compare
            # was always False and the no-interpolation fast path was dead
            # code. Compare as tuples so grids matching the native table size
            # skip the (identity) bicubic interpolation.
            if tuple(shape) == tuple(self.weight.shape[:-1]):
                pos_embs.append(self.weight.flatten(end_dim=1))
            else:
                pos_embs.append(
                    F.interpolate(
                        self.weight.permute((2, 0, 1)).unsqueeze(0),
                        size=shape,
                        mode=self.interpolation_mode,
                    )
                    .squeeze(0)
                    .permute((1, 2, 0))
                    .flatten(end_dim=1)
                )
        out = x + torch.cat(pos_embs)
        return out
199
+
200
+
201
class MoonVisionPatchEmbed(nn.Module):
    """Convolutional patch embedding followed by a learnable, per-grid
    interpolated 2D positional embedding."""

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: Union[int, Tuple[int, int]] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
    ):
        super().__init__()
        assert isinstance(
            patch_size, (int, Sequence)
        ), f"Invalid patch_size type: {type(patch_size)}"
        # Normalize a scalar patch size to an (h, w) pair.
        patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        assert (
            len(patch_size) == 2
        ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
        self.patch_size = patch_size

        # NOTE: module creation order (proj, then pos_emb) is kept so
        # parameter initialization under a fixed seed is unchanged.
        self.proj = nn.Conv2d(
            in_dim, out_dim, kernel_size=patch_size, stride=patch_size
        )

        self.pos_emb = Learnable2DInterpPosEmb(
            height=pos_emb_height, width=pos_emb_width, dim=out_dim
        )

    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
            grid_hws (N, 2): grid height and width
        Returns:
            (L, Cout) tensor
        """
        # Each patch collapses to a single out_dim vector after the conv.
        embedded = self.proj(x)
        embedded = embedded.view(embedded.size(0), -1)
        return self.pos_emb(embedded, grid_hws)
242
+
243
+
244
+ class Rope2DPosEmb(nn.Module):
245
+ """2D rotary position embedding with multi-resolution support.
246
+ This class is intended to be used in the following way:
247
+ 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
248
+ 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
249
+ 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
250
+ The rope is shared across all attention layers and all heads.
251
+ Refs:
252
+ - RoFormer: https://arxiv.org/abs/2104.09864
253
+ - VisionLLaMA: https://arxiv.org/abs/2403.00522
254
+ - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
255
+ Args:
256
+ dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
257
+ max_height (int): the maximum height of the 2D grid
258
+ max_width (int): the maximum width of the 2D grid
259
+ theta_base (float): the base of the theta
260
+ device (str): the device to store the precomputed cis
261
+ """
262
+
263
+ def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
264
+ super().__init__()
265
+ self.dim = dim
266
+ assert self.dim % 4 == 0, "dim must be divisible by 4"
267
+ self.max_height = max_height
268
+ self.max_width = max_width
269
+ self.theta_base = theta_base
270
+
271
+ self.freqs_cis = None
272
+
273
+ def extra_repr(self):
274
+ return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
275
+
276
+ def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
277
+ """Calculate the cis(freqs) for each position in the 2D grid.
278
+ Return: complex tensor of shape (max_height, max_width, dim//2) and value:
279
+ height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
280
+ weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
281
+ note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
282
+ """
283
+ N = self.max_height * self.max_width
284
+ flat_pos = torch.arange(0, N).float().to(device)
285
+ x_pos = flat_pos % self.max_width
286
+ y_pos = flat_pos // self.max_width
287
+ dim_range = (
288
+ torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
289
+ ) # C/4
290
+ freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
291
+ x_freqs = torch.outer(x_pos, freqs).float() # N, C/4
292
+ y_freqs = torch.outer(y_pos, freqs).float() # N, C/4
293
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4
294
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4
295
+ # N, C/4, 2
296
+ freqs_cis = torch.cat(
297
+ [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
298
+ )
299
+ # max_height, max_width, C/2
300
+ freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
301
+ return freqs_cis
302
+
303
+ def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
304
+ """
305
+ Args:
306
+ grid_hws (torch.Tensor): grid height and width
307
+ Returns:
308
+ freqs_cis: tensor of shape (sum(t * height * width), dim//2)
309
+ """
310
+ if self.freqs_cis is None:
311
+ self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
312
+
313
+ shapes = grid_hws.tolist()
314
+ assert all(
315
+ 1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
316
+ ), (
317
+ shapes,
318
+ self.max_height,
319
+ self.max_width,
320
+ )
321
+ freqs_cis = torch.cat(
322
+ [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in shapes],
323
+ dim=0,
324
+ )
325
+ return freqs_cis
326
+
327
+
328
class MLP2(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    Args:
        dims: [in_dim, hidden_dim, out_dim].
        activation: callable applied between the two linear layers.
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
        self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
        self.activation = activation
        # He-style truncated-normal init; fc0-then-fc1 order is preserved so
        # RNG consumption (seeded initialization) matches the original.
        for layer in (self.fc0, self.fc1):
            nn.init.trunc_normal_(layer.weight, std=math.sqrt(2 / layer.in_features))
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(self.activation(self.fc0(x)))
350
+
351
+
352
class MoonVitEncoderLayer(nn.Module):
    """Pre-norm ViT block: packed self-attention with 2D RoPE, then an MLP."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = "eager",
        activation=F.gelu,
        attn_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation

        # Module creation order is kept identical to the original so that
        # seeded parameter initialization is unchanged.
        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Optional[torch.Tensor] = None,
    ):
        """Fused-QKV attention over a packed token sequence.

        Args:
            x (torch.Tensor): packed hidden states, (tot_seqlens, hidden_dim).
            cu_seqlens (torch.Tensor): cumulative segment lengths.
            rope_freqs_cis: per-position complex rotations for 2D RoPE.
        """
        qkv = self.wqkv(x)
        # (tot_seqlens, 3, num_heads, head_dim)
        qkv = qkv.view(
            *qkv.size()[:-1],
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        queries, keys, values = torch.unbind(qkv, dim=-3)

        queries, keys = apply_rope(queries, keys, rope_freqs_cis)

        attend = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        out = attend(
            queries, keys, values, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
        )
        return self.wo(out)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """Residual attention followed by a residual MLP.

        The output has the same shape as the input (packed (L, D)).
        """
        hidden_states = hidden_states + self.attention_qkvpacked(
            self.norm0(hidden_states), cu_seqlens, rope_freqs_cis=rope_freqs_cis
        )
        hidden_states = hidden_states + self.mlp(self.norm1(hidden_states))
        return hidden_states
431
+
432
+
433
class MoonVitEncoder(nn.Module):
    """Stack of `MoonVitEncoderLayer`s sharing one 2D RoPE table."""

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
    ) -> None:
        super().__init__()

        # Rotary table sized per attention head, for grids up to 512x512.
        self.rope_2d = Rope2DPosEmb(
            block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
        )
        self.blocks = nn.ModuleList(
            MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)
        )
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, grid_hws: torch.Tensor
    ) -> torch.Tensor:
        """Run packed tokens through every block, then a final LayerNorm."""
        rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws)

        # Cumulative token counts per image: [0, n1, n1+n2, ...] as int32.
        tokens_per_image = grid_hws[:, 0] * grid_hws[:, 1]
        zero = torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype)
        cu_seqlens = torch.cat((zero, tokens_per_image)).cumsum(
            dim=0, dtype=torch.int32
        )

        for block in self.blocks:
            hidden_states = block(
                hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
            )

        return self.final_layernorm(hidden_states)
472
+
473
+
474
def patch_merger(
    x: torch.Tensor,
    grid_hws: torch.Tensor,
    # Annotation fixed: `list[int, int]` is not a valid generic form, and the
    # default is in fact a tuple.
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> List[torch.Tensor]:
    """Group neighboring patch tokens into merge-kernel windows, per image.

    Args:
        x: packed patch tokens of shape (sum(h*w), d_model), row-major over
            each image's patch grid.
        grid_hws: per-image patch-grid sizes, shape (num_images, 2).
        merge_kernel_size: (kernel_height, kernel_width) of the merge window.

    Returns:
        One tensor per image of shape (h//kh * w//kw, kh*kw, d_model): each
        row holds the tokens of one kernel window, ready to be flattened by
        the projector.
    """
    d_model = x.size(-1)

    outputs = []
    pre_sum = 0
    for x_shape in grid_hws.tolist():
        height, width = x_shape[0], x_shape[1]
        # Tokens belonging to the current image.
        seq = x[pre_sum : pre_sum + height * width]
        # Reshape along the merge kernel and collect each window's tokens
        # into its own row: (new_h, kh, new_w, kw, d) -> (new_h, new_w, kh, kw, d).
        kernel_height, kernel_width = merge_kernel_size
        new_height, new_width = height // kernel_height, width // kernel_width
        reshaped_seq = seq.view(
            new_height, kernel_height, new_width, kernel_width, d_model
        )
        reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
        padded_seq = reshaped_seq.view(
            new_height * new_width, kernel_height * kernel_width, -1
        )
        outputs.append(padded_seq)
        pre_sum += height * width

    return outputs
501
+
502
class MoonVitVLProjector(nn.Module):
    """Projects merged vision tokens into the language-model hidden space.

    Each merge window's tokens are LayerNormed, flattened together into a
    single `in_channels * kh * kw` feature, then passed through a two-layer
    MLP to `out_dim`.
    """

    def __init__(
        self,
        in_channels: int,
        merge_kernel_size: tuple[int, int],
        hidden_act: str = "gelu",
        ln_eps: float = 1e-5,
        out_dim: int = 4096,
    ):
        super().__init__()
        self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]

        # Fix: was `nn.nn.LayerNorm`, which raises AttributeError at
        # construction time (`torch.nn` has no attribute `nn`).
        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = ACT2FN[hidden_act]
        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """LayerNorm per token, flatten each merge window, then project."""
        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
526
+
527
+
528
class MultiModalProjector(nn.Module):
    """Config-driven projector from merged vision tokens to the text hidden size."""

    def __init__(self, config):
        super().__init__()

        # Each merge window flattens kh*kw vision tokens into one feature.
        kh = config.merge_kernel_size[0]
        kw = config.merge_kernel_size[1]
        self.hidden_size = config.hidden_size * kh * kw

        self.pre_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-05)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(
            self.hidden_size, config.text_hidden_size, bias=True
        )

    def forward(self, image_features: list[torch.Tensor]) -> torch.Tensor:
        """Concatenate per-image windows, normalize, flatten, and project."""
        stacked = torch.cat(image_features, dim=0)
        normed = self.pre_norm(stacked).view(-1, self.hidden_size)
        return self.linear_2(self.act(self.linear_1(normed)))
554
+
555
+
556
class MoonVitPretrainedModel(PreTrainedModel):
    """MoonViT vision tower: patch embedding, RoPE transformer encoder,
    patch merger and a projector into the text model's hidden space."""

    config_class = MoonViTConfig
    model_type = "moonvit"
    _no_split_modules = ["PackingTransformer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # Work on a copy so later mutations cannot leak into the caller's config.
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.patch_embed = MoonVisionPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
        )

        self.encoder = MoonVitEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": PytorchGELUTanh(),
                "attn_bias": True,
                "attn_implementation": config._attn_implementation,
            },
        )
        self.multi_modal_projector = MultiModalProjector(config)

    def forward(
        self, pixel_values: torch.Tensor, grid_hws: torch.Tensor
    ) -> torch.Tensor:
        """Encode packed image patches into text-space token embeddings.

        Args:
            pixel_values (torch.Tensor): packed input patches.
            grid_hws (torch.Tensor): per-image patch-grid sizes, (num_images, 2).
        Returns:
            torch.Tensor: projected output tokens for all images, concatenated.
        """
        tokens = self.patch_embed(pixel_values, grid_hws)
        tokens = self.encoder(tokens, grid_hws)
        merged = patch_merger(
            tokens, grid_hws, merge_kernel_size=self.merge_kernel_size
        )
        return self.multi_modal_projector(merged)
606
+
preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_auto_class": "AutoImageProcessor",
3
+ "auto_map": {
4
+ "AutoImageProcessor": "image_processing_moonvit.MoonViTImageProcessor"
5
+ },
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_std": [
12
+ 0.5,
13
+ 0.5,
14
+ 0.5
15
+ ],
16
+ "in_token_limit": 16384,
17
+ "merge_kernel_size": [
18
+ 2,
19
+ 2
20
+ ],
21
+ "num_pooled_tokens": 1024,
22
+ "pad_input": true,
23
+ "patch_size": 14
24
+ }