| import torch |
| import torch.nn as nn |
| from typing import Callable, Tuple |
|
|
|
|
def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with a balanced matching set (50%, 50%).

    Input size is [batch, tokens, channels].
    r indicates the number of tokens to remove (max 50% of tokens).

    Tokens at even positions form the source set and tokens at odd
    positions form the destination set. Each source token is paired with
    its most similar destination token (cosine similarity of ``metric``),
    and the ``r`` best-matching source tokens are folded into their
    destinations.

    Args:
        metric: Similarity features of shape [batch, tokens, channels].
        r: Requested number of tokens to remove; clamped to ``tokens // 2``.

    Returns:
        A ``(merge, unmerge)`` pair of closures bound to this matching.

    Raises:
        AssertionError: If the clamped ``r`` is not positive.
    """
    t = metric.shape[1]
    # Only the even-indexed half can be merged away, hence the 50% cap.
    r = min(r, t // 2)

    assert r > 0, r

    with torch.no_grad():
        # Clamp the norm so an all-zero token normalises to zeros instead of
        # NaN; a single NaN would otherwise poison every row maximum below.
        norm = metric.norm(dim=-1, keepdim=True)
        metric = metric / norm.clamp_min(torch.finfo(metric.dtype).tiny)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        # Best destination (and its score) for every source token.
        node_max, node_idx = scores.max(dim=-1)
        # Source tokens ranked from most to least similar to their match.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # source tokens that are kept
        src_idx = edge_idx[..., :r, :]  # source tokens that are merged away
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Merge ``x`` ([batch, tokens, channels]) with the stored matching.

        ``mode`` selects how merged sources are reduced into their
        destination: ``"sum"`` adds them, any other value is forwarded to
        ``Tensor.scatter_reduce`` (e.g. ``"mean"``, ``"amax"``).

        Returns the unmerged source tokens followed by all destination
        tokens, i.e. ``tokens - r`` tokens in total.
        """
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        index = dst_idx.expand(n, r, c)
        if mode == "sum":
            dst = dst.scatter_add(-2, index, src)
        else:
            dst = dst.scatter_reduce(-2, index, src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Expand a merged tensor back to the original ``tokens`` length.

        Every merged-away source position receives a copy of the
        destination token it was merged into.
        """
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, t, c, device=x.device, dtype=x.dtype)

        # Destinations live at odd positions, sources at even positions.
        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge
|
|
|
|
def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge tokens as a size-weighted average.

    Each token is scaled by its size, the scaled tokens and the sizes are
    both merged with ``mode="sum"``, and the summed tokens are divided by
    the summed sizes.

    Args:
        merge: Callable invoked as ``merge(tensor, mode="sum")``.
        x: Token tensor to merge.
        size: Per-token sizes; when ``None`` every token counts as 1
            (a tensor of ones shaped like ``x[..., 0, None]``).

    Returns:
        The merged tokens and the merged token sizes.
    """
    token_size = size if size is not None else torch.ones_like(x[..., 0, None])

    weighted_sum = merge(x * token_size, mode="sum")
    merged_size = merge(token_size, mode="sum")

    return weighted_sum / merged_size, merged_size
|
|
|
|
|
|
|
|
class ToMe16_mlp_hd64(nn.Module):
    """Projector that reduces vision tokens with ToMe, then applies an MLP.

    ``forward`` merges the incoming patch tokens down to a fixed budget
    (64 tokens per input without compression, 16 tokens per frame with
    compression) and projects them from ``config.mm_hidden_size`` to
    ``config.hidden_size`` with a Linear-GELU-Linear stack.
    """

    def __init__(self, config, vision_cfg):
        """Build the projector.

        Args:
            config: Object providing ``mm_hidden_size``, ``hidden_size`` and
                ``mm_pos_num_frames``.
            vision_cfg: Object providing ``image_size``, ``patch_size`` and
                ``num_attention_heads``.
        """
        super().__init__()
        self._config = config
        self.mm_hidden_size = config.mm_hidden_size
        # Patches per image side; forward() expects hw * hw tokens per item.
        self.hw = vision_cfg.image_size // vision_cfg.patch_size
        self.num_attention_heads = vision_cfg.num_attention_heads
        self.mlp = nn.Sequential(
            nn.Linear(config.mm_hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )
        self.max_pos_hw = self.hw
        self.max_pos_num_frames = config.mm_pos_num_frames
        # NOTE(review): 8 * 8 = 64 and 4 * 4 = 16 match the token budgets
        # hard-coded in forward(); presumably related - confirm.
        self.num_image_patches_per_side = 8
        self.num_frame_patches_per_side = 4

    def merge_tokens(self, x, target_num_token):
        r"""Merge the tokens of ``x`` until ``target_num_token`` remain.

        A single bipartite matching can remove at most half of the tokens,
        so the reduction is applied in several rounds when needed.

        Example::

            x = torch.randn(10, 2560, c)
            x = self.merge_tokens(x, target_num_token=1280)

        Args:
            x: Tensor of shape [batch, tokens, channels]; ``channels`` is
                split into ``self.num_attention_heads`` equal groups.
            target_num_token: Number of tokens to keep (at least 1 and at
                most the current token count).

        Returns:
            Tensor of shape [batch, target_num_token, channels]; ``x``
            itself when it already has ``target_num_token`` tokens.

        Raises:
            ValueError: If ``target_num_token`` is smaller than 1.
            AssertionError: If ``x`` has fewer tokens than requested.
        """
        b, p, c = x.shape
        if target_num_token < 1:
            # Halving can never reach zero tokens; without this guard the
            # planning loop below would never terminate.
            raise ValueError(
                f"target_num_token must be at least 1, got {target_num_token}"
            )
        assert p >= target_num_token, (
            f"{p} should be greater than or equal to {target_num_token}"
        )
        if p == target_num_token:
            # Already at the requested budget: nothing to merge.
            return x

        # Plan the per-round removals: at most half of the remaining tokens
        # per round, with the final round removing exactly the remainder.
        r_merge_list = []
        remaining = p
        while remaining > target_num_token:
            r = min(remaining - target_num_token, remaining // 2)
            r_merge_list.append(r)
            remaining -= r

        head = self.num_attention_heads
        dim = c // head
        size = None
        for r in r_merge_list:
            # Similarity metric: channel groups averaged over the head axis.
            metric = x.reshape(b, p, head, dim).mean(2)
            merge, _ = bipartite_soft_matching(metric, r)
            x, size = merge_wavg(merge, x, size)
            _, p, _ = x.shape

        return x

    def forward(self, x, compress=False, local_num_frames=-1):
        """Merge tokens down to the fixed budget and project them.

        Args:
            x: Tensor whose second dimension equals ``self.hw ** 2``.
            compress: When true, the leading dimension is treated as frames
                that are flattened into shared token sequences and reduced
                to 16 tokens per frame; otherwise each item is reduced to
                64 tokens.
            local_num_frames: Frames per group when compressing; ``-1``
                puts all frames into one group. Any value other than ``-1``
                or ``1`` requires ``compress=True``.

        Returns:
            The merged tokens after the MLP projection.
        """
        height = width = self.hw
        assert height * width == x.shape[1]

        if local_num_frames != -1 and local_num_frames != 1:
            assert compress is True
        if compress:
            if local_num_frames != -1:
                num_frames = local_num_frames
                # NOTE(review): if x.shape[0] is not a multiple of
                # local_num_frames this reshape may still succeed and mix
                # frames across groups - confirm callers guarantee it.
                x = x.reshape(x.shape[0] // local_num_frames, -1, x.shape[-1])
            else:
                num_frames = x.shape[0]
                x = x.reshape(1, -1, x.shape[-1])
            num_tome_tokens = 16 * num_frames
        else:
            num_tome_tokens = 64

        x = self.merge_tokens(x, target_num_token=num_tome_tokens)
        x = self.mlp(x)
        return x

    @property
    def config(self):
        """Return the projector-type descriptor for this module."""
        return {"mm_projector_type": "tome16_mlp_hd64"}
|
|
|
|
|
|
|
|
def build_vision_projector(config, delay_load=False, **kwargs):
    """Instantiate the vision projector selected by ``config``.

    The projector type is read from ``config.mm_projector_type`` and falls
    back to ``"linear"`` when the attribute is missing. Only
    ``'tome16_mlp_hd64'`` is handled here; it is built from ``config`` and
    ``kwargs["vision_cfg"]``. ``delay_load`` is accepted but not used.

    Raises:
        ValueError: If the projector type is not recognised.
    """
    kind = getattr(config, "mm_projector_type", "linear")

    if kind != 'tome16_mlp_hd64':
        raise ValueError(f"Unknown projector type: {kind}")

    return ToMe16_mlp_hd64(config, kwargs["vision_cfg"])
|
|