| from transformers import PreTrainedModel |
| from .configuration_vitmix import ViTMixConfig |
|
|
|
|
| |
| import torch |
| from torch import nn |
|
|
| from einops import rearrange |
| from einops.layers.torch import Rearrange |
|
|
| from st_moe_pytorch import SparseMoEBlock, MoE |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def pair(t):
    """Return ``t`` unchanged when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
|
|
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Build a fixed 2D sine/cosine positional-embedding table.

    Args:
        h: Number of rows in the position grid.
        w: Number of columns in the position grid.
        dim: Width of each embedding vector; must be a multiple of 4 because
            the vector is the concatenation of four equally sized parts.
        temperature: Base of the geometric frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        A tensor of shape ``(h * w, dim)``. Rows follow the flattened
        ``(h, w)`` grid (row index varies slowest); each row is
        ``[sin(x), cos(x), sin(y), cos(y)]`` over ``dim // 4`` frequencies.

    Raises:
        AssertionError: If ``dim`` is not a multiple of 4.
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    num_freqs = dim // 4
    # With a single frequency (dim == 4) the original denominator
    # ``num_freqs - 1`` is 0 and 0 / 0 produced NaN for the whole table.
    # Clamping to 1 leaves every dim >= 8 result untouched and gives
    # omega == 1 for the single-frequency case.
    omega = torch.arange(num_freqs) / max(num_freqs - 1, 1)
    omega = 1.0 / (temperature ** omega)

    # Outer product of flattened coordinates with the frequency vector.
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)
|
|
| |
|
|
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP: LayerNorm -> Linear -> GELU -> Linear.

    Maps inputs whose last dimension is ``dim`` through a hidden width of
    ``hidden_dim`` and back to ``dim``.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Order matters: it fixes the state-dict keys (net.0, net.1, net.3).
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        projected = self.net(x)
        return projected
|
|
class Attention(nn.Module):
    """Pre-norm multi-head self-attention with bias-free projections.

    The input is layer-normalised, projected to queries/keys/values in a
    single linear layer, attended per head with scaled dot-product softmax,
    and projected back to ``dim``.
    """

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)

        inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        """Attend over the token axis of a ``(batch, tokens, dim)`` input."""
        batch, tokens, _ = x.shape
        normed = self.norm(x)

        # One projection yields q, k and v side by side on the last axis.
        # Each part goes (batch, tokens, heads * d) -> (batch, heads, tokens, d).
        q, k, v = (
            part.reshape(batch, tokens, self.heads, part.shape[-1] // self.heads).transpose(1, 2)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        mixed = torch.matmul(weights, v)

        # (batch, heads, tokens, d) -> (batch, tokens, heads * d)
        merged = mixed.transpose(1, 2).reshape(batch, tokens, self.heads * mixed.shape[-1])
        return self.to_out(merged)
|
|
class Transformer(nn.Module):
    """Stack of attention layers whose feed-forward alternates dense and sparse-MoE.

    Layers at even indices (0, 2, ...) pair ``Attention`` with a dense
    ``FeedForward``; layers at odd indices pair it with a ``SparseMoEBlock``
    wrapping an ``MoE`` with ``num_experts`` experts. A final ``LayerNorm``
    is applied to the output of the last layer.

    Args:
        dim: Token embedding width.
        depth: Number of (attention, feed-forward) layers.
        heads: Number of attention heads.
        dim_head: Width of each attention head.
        mlp_dim: Hidden width of the dense feed-forward layers.
        num_experts: Number of experts in each MoE layer.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_experts):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for layer_idx in range(depth):
            # Attention is built before the feed-forward so parameter
            # initialisation order (and therefore seeded init) is unchanged.
            attention = Attention(dim, heads = heads, dim_head = dim_head)
            if layer_idx % 2 == 0:
                feed_forward = FeedForward(dim, mlp_dim)
            else:
                feed_forward = SparseMoEBlock(
                    MoE(
                        dim = dim,
                        num_experts = num_experts,
                        gating_top_n = 2,
                        threshold_train = 0.2,
                        threshold_eval = 0.2,
                        capacity_factor_train = 1.25,
                        capacity_factor_eval = 2.,
                        balance_loss_coef = 1e-2,
                        router_z_loss_coef = 1e-3,
                    ),
                    add_ff_before = True,
                    add_ff_after = True
                )
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        """Run every layer with residual connections, then the final norm."""
        for attne, ff in self.layers:
            x = attne(x) + x
            # Call the feed-forward exactly once. The previous bare
            # ``try/except`` re-ran the block after a failed ``tuple + tensor``
            # add (doubling the MoE cost) and silently swallowed any other
            # error raised inside the block.
            ff_out = ff(x)
            if isinstance(ff_out, tuple):
                # Tuple-returning blocks carry the hidden states in slot 0.
                # NOTE(review): the remaining entries are presumably the MoE
                # auxiliary losses; they are discarded here, so they never
                # reach the training loss -- confirm whether that is intended.
                ff_out = ff_out[0]
            x = ff_out + x
        return self.norm(x)
|
|
class SimpleViTMIX(nn.Module):
    """Vision transformer classifier with fixed sin/cos position embeddings.

    Images are cut into non-overlapping patches, linearly embedded, offset by
    a 2D sin/cos positional table, passed through ``Transformer``, mean-pooled
    over the token axis and projected to ``num_classes`` logits.

    All constructor arguments are keyword-only. ``image_size`` and
    ``patch_size`` may each be an int or a ``(height, width)`` tuple.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, num_experts = 12):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # Flattened size of one patch, and the patch-grid extent.
        patch_dim = channels * patch_height * patch_width
        grid_rows = image_height // patch_height
        grid_cols = image_width // patch_width

        patch_stages = [
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        ]
        self.to_patch_embedding = nn.Sequential(*patch_stages)

        # Stored as a plain tensor attribute (neither parameter nor buffer);
        # forward() moves it to the input's device on every call.
        self.pos_embedding = posemb_sincos_2d(h = grid_rows, w = grid_cols, dim = dim)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_experts)

        # Recorded for reference only; forward() always mean-pools.
        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """Return class logits for a batch of images laid out as ``(b, c, H, W)``."""
        tokens = self.to_patch_embedding(img)
        tokens += self.pos_embedding.to(img.device, dtype=tokens.dtype)

        encoded = self.transformer(tokens)
        pooled = encoded.mean(dim = 1)

        return self.linear_head(self.to_latent(pooled))
|
|
|
|
|
|
class ViTMixModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``SimpleViTMIX``.

    The wrapped network is built from the ``image_size``, ``patch_size``,
    ``num_classes``, ``dim``, ``depth``, ``heads``, ``mlp_dim`` and
    ``num_experts`` fields of the config.
    """
    config_class = ViTMixConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = SimpleViTMIX(
            image_size = config.image_size,
            patch_size = config.patch_size,
            num_classes = config.num_classes,
            dim = config.dim,
            depth = config.depth,
            heads = config.heads,
            mlp_dim = config.mlp_dim,
            num_experts = config.num_experts
        )

    def forward(self, tensor, labels = None):
        """Classify ``tensor`` and optionally compute the cross-entropy loss.

        Args:
            tensor: Image batch forwarded to the wrapped ``SimpleViTMIX``.
            labels: Optional classification targets. When given, the
                cross-entropy between the logits and ``labels`` is returned
                as well.

        Returns:
            ``{"logits": logits}`` when ``labels`` is ``None``, otherwise
            ``{"loss": loss, "logits": logits}``.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        # ``torch.nn`` has no ``cross_entropy`` attribute; the functional
        # form lives in ``torch.nn.functional``. The old call raised
        # AttributeError whenever labels were supplied.
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}