import torch
from torch import nn
import math
from fractions import Fraction
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel


class InterpolateDownsampler:
    """Spatial downsampling of patch features via area interpolation (configurable `mode`)."""

    def __init__(self, config, mode="area"):
        # Number of patches per image side, e.g. 224 // 14 = 16.
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        # downsample_rate is a fraction encoded as a string, e.g. "1/2".
        self.new_image_side = int(self.orig_image_side * Fraction(config.downsample_rate))
        self.mode = mode

    def __call__(self, image_features):
        batch_size, _, dim = image_features.size()
        # (B, S, D) -> (B, side, side, D) -> (B, D, side, side) for interpolate.
        up_shape = [batch_size] + [self.orig_image_side] * 2 + [dim]
        large_image_permuted = image_features.view(up_shape).permute(0, 3, 1, 2)
        small_image_permuted = torch.nn.functional.interpolate(
            large_image_permuted,
            size=(self.new_image_side, self.new_image_side),
            mode=self.mode,
        )
        # Back to (B, new_side * new_side, D), raster order.
        final = small_image_permuted.permute(0, 2, 3, 1).flatten(1, 2)
        return final
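

# A minimal, hypothetical smoke test for InterpolateDownsampler (not part of
# the original module). The SimpleNamespace config and all sizes below are
# illustrative assumptions, not values from a real checkpoint.
def _demo_interpolate_downsampler():
    from types import SimpleNamespace

    config = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=224, patch_size=14),  # 16x16 patches
        downsample_rate="1/2",  # halve each spatial side -> 4x fewer tokens
    )
    downsampler = InterpolateDownsampler(config)
    features = torch.randn(2, 16 * 16, 64)  # (batch, patches, hidden_dim)
    out = downsampler(features)
    assert out.shape == (2, 8 * 8, 64)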


class SpatialOffsetDownsampler:
    """
    Downsampler that samples one position from each 2x2 block across the image,
    reducing the token count 4x while maintaining coverage of the full spatial
    extent and preserving local continuity.
    """

    def __init__(self, config, offset=0):
        """
        Args:
            config: Model configuration.
            offset: Integer offset (0, 1, 2, or 3) for the position kept within
                each 2x2 block:
                0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right.
        """
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = self.orig_image_side // 2
        self.offset = offset
        self.offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
        self.offset_h, self.offset_w = self.offsets[offset]

    def __call__(self, image_features):
        batch_size, seq_len, hidden_dim = image_features.shape
        # (B, S, D) -> (B, side, side, D) raster grid of patch features.
        features_2d = image_features.reshape(batch_size, self.orig_image_side, self.orig_image_side, hidden_dim)
        n_blocks = self.new_image_side
        # Split the grid into 2x2 blocks: (B, n, 2, n, 2, D).
        features_blocks = features_2d.reshape(
            batch_size, n_blocks, 2, n_blocks, 2, hidden_dim
        )
        # Keep the same position in every block, then flatten back to (B, S/4, D).
        sampled = features_blocks[:, :, self.offset_h, :, self.offset_w, :]
        sampled = sampled.reshape(batch_size, -1, hidden_dim)
        return sampled
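

# A hypothetical sketch (not in the original file) showing that
# SpatialOffsetDownsampler keeps exactly one token per 2x2 block: on a 4x4
# grid whose tokens are numbered 0..15 in raster order, offset=0 keeps the
# top-left token of each block, i.e. positions 0, 2, 8, 10. The config
# values are illustrative assumptions.
def _demo_spatial_offset_downsampler():
    from types import SimpleNamespace

    config = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=4, patch_size=1),  # 4x4 grid
    )
    downsampler = SpatialOffsetDownsampler(config, offset=0)
    features = torch.arange(16, dtype=torch.float32).view(1, 16, 1)
    out = downsampler(features)
    assert out.flatten().tolist() == [0.0, 2.0, 8.0, 10.0]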


class WindowQFormerDownsampler(nn.Module):
    """Window-based Q-Former downsampler: compresses each window of image
    patches into a smaller grid of learned queries via cross-attention."""

    def __init__(self, config, spatial_offset=None):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        self.dropout = nn.Dropout(config.projector_dropout)
        # The downsampled features seed the queries (one set per output window).
        if spatial_offset is not None:
            self.downsampler = SpatialOffsetDownsampler(config, offset=spatial_offset)
        else:
            self.downsampler = InterpolateDownsampler(config)
        # Single-layer Q-Former with cross-attention in every layer;
        # one attention head per 64 channels.
        configuration = Blip2QFormerConfig(
            hidden_size=vision_hidden_size,
            num_attention_heads=vision_hidden_size // 64,
            intermediate_size=3072,
            num_hidden_layers=1,
            encoder_hidden_size=vision_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # downsample_rate "q/w": each w x w window of patches maps to q x q queries.
        q, w = config.downsample_rate.split("/")
        self.query_side, self.window_side = int(q), int(w)
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
        # Learned queries (shared across windows) and per-window position embeddings.
        self.query = nn.Parameter(torch.randn(1, self.query_length, vision_hidden_size) * embed_std)
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, vision_hidden_size) * embed_std)
        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)

    def _win(self, x, side, win):
        """
        (B, side*side, C) raster -> (B*n*n, win*win, C), where n = side // win.
        Windows are raster-ordered, and tokens inside each window are raster-ordered.
        """
B, _, C = x.shape
n = side // win
return (
x.view(B, side, side, C)
.view(B, n, win, n, win, C)
.transpose(2, 3) # (B, n, n, win, win, C)
.flatten(0, 2) # (B*n*n, win, win, C)
.flatten(1, 2) # (B*n*n, win*win, C)
)
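
    # Worked example (illustrative): with side=4 and win=2, a raster grid
    #     0  1  2  3
    #     4  5  6  7
    #     8  9 10 11
    #    12 13 14 15
    # becomes four raster-ordered windows:
    #    [0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]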

    def _unwin(self, xw, n, win):
        """
        (B*n*n, win*win, C) -> (B, (n*win)^2, C) raster order; inverse of _win.
        """
Bnn, _, C = xw.shape
assert Bnn % (n * n) == 0
B = Bnn // (n * n)
side = n * win
return (
xw.view(B, n, n, win, win, C)
.transpose(2, 3) # (B, n, win, n, win, C)
.contiguous()
.view(B, side, side, C)
.flatten(1, 2)
)

    def forward(self, image_features):
        B, HW, C = image_features.shape
        assert HW == self.image_side * self.image_side
        n = self.image_side // self.window_side
        image_features = self.norm(image_features)
        # Encoder side: raw patches grouped into window_side x window_side windows.
        enc = self._win(image_features, self.image_side, self.window_side)
        # Query side: downsampled patches grouped into query_side x query_side windows.
        downsampled = self.downsampler(image_features)
        new_side = n * self.query_side
        downsampled_w = self._win(downsampled, new_side, self.query_side)
        query_embeds = self.query + downsampled_w
        encoder_embeds = self.dropout(enc + self.image_positions)
        # Each window is an independent batch element for the Q-Former, so
        # queries cross-attend only to the patches of their own window.
        out_w = self.qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_embeds,
            return_dict=True,
        ).last_hidden_state
        out = self._unwin(out_w, n=n, win=self.query_side)
        out = self.dropout(out)
        return self.out_linear(out)
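

# A hypothetical end-to-end shape check for WindowQFormerDownsampler (not in
# the original file). The SimpleNamespace config mimics only the fields this
# module reads; every size below is an illustrative assumption. With
# downsample_rate="1/2", each 2x2 window of patches is compressed into a
# single query token, so 16x16 patches become 8x8 output tokens.
def _demo_window_qformer_downsampler():
    from types import SimpleNamespace

    config = SimpleNamespace(
        vision_config=SimpleNamespace(image_size=224, patch_size=14, hidden_size=256),
        text_config=SimpleNamespace(hidden_size=512),
        projector_dropout=0.1,
        downsample_rate="1/2",
    )
    model = WindowQFormerDownsampler(config)
    features = torch.randn(2, 16 * 16, 256)  # (batch, patches, vision_hidden)
    out = model(features)
    assert out.shape == (2, 8 * 8, 512)  # (batch, queries, llm_hidden)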