| from functools import partial |
|
|
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange |
| from rotary_embedding_torch import RotaryEmbedding, broadcat |
| from torch import nn |
|
|
|
|
| |
|
|
|
|
def exists(val):
    """Return True unless ``val`` is None.

    Falsy values such as ``0``, ``''`` or ``False`` still count as existing.
    """
    if val is None:
        return False
    return True
|
|
|
|
def default(val, d):
    """Return ``val``, falling back to ``d`` only when ``val`` is None.

    Other falsy values (``0``, ``''``, ``False``) are returned unchanged.
    """
    return d if val is None else val
|
|
|
|
def cast_tuple(val, depth = 1):
    """Coerce ``val`` into a tuple.

    Tuples are returned unchanged and lists are converted element-for-element.
    Any other value is repeated ``depth`` times. ``depth`` is ignored for
    tuple/list inputs, so their length is not checked against it.
    """
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val,) * depth
|
|
|
|
def max_neg_value(t):
    """Return the most negative finite number representable in ``t``'s dtype.

    ``torch.finfo`` only accepts floating-point dtypes, so this is meant for
    float tensors.
    """
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
|
|
|
|
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    """Softmax along ``dim`` with max-subtraction performed at a reduced scale.

    Steps:
    - Divide the logits by ``alpha``.
    - Subtract the (detached) maximum of the scaled tensor along ``dim``.
    - Multiply back by ``alpha`` before the softmax.

    Mathematically this equals ``t.softmax(dim)``. The down-scaling presumably
    keeps intermediates small in low precision.
    """
    scaled = t / alpha
    peak = scaled.amax(dim = dim, keepdim = True).detach()
    return torch.softmax((scaled - peak) * alpha, dim = dim)
|
|
|
|
def route_args(router, args, depth):
    """Split keyword arguments into per-layer ``(f_kwargs, g_kwargs)`` dict pairs.

    Args:
        router: maps an argument name to a per-layer sequence of ``(to_f, to_g)``
            flags. A truthy flag forwards the argument to that branch.
        args: the keyword arguments to distribute.
        depth: number of layers, i.e. the length of the returned list.

    Returns:
        A list of ``depth`` pairs of dicts. Names missing from ``router`` are
        dropped. A route shorter than ``depth`` only feeds its leading layers.
    """
    per_layer = [({}, {}) for _ in range(depth)]

    for name, value in args.items():
        if name not in router:
            continue
        # zip() stops at the shorter of the layer list and this name's route.
        for layer_idx, (branch_kwargs, flags) in enumerate(zip(per_layer, router[name])):
            f_kwargs, g_kwargs = branch_kwargs
            to_f, to_g = flags
            new_f = dict(f_kwargs)
            new_g = dict(g_kwargs)
            if to_f:
                new_f[name] = value
            if to_g:
                new_g[name] = value
            per_layer[layer_idx] = (new_f, new_g)

    return per_layer
|
|
|
|
| |
class SequentialSequence(nn.Module):
    """Apply a stack of residual ``(f, g)`` pairs in order.

    For each layer: ``x = x + f(x, **f_kwargs)``, then ``x = x + g(x, **g_kwargs)``.
    The per-branch keyword arguments come from ``route_args``.

    Args:
        layers: sequence of ``(f, g)`` pairs, e.g. an ``nn.ModuleList`` of
            two-element ``nn.ModuleList``s.
        args_route: maps a keyword name to a per-layer sequence of
            ``(to_f, to_g)`` flags. Every entry must have one flag pair per
            layer. Defaults to no routing (a fresh empty dict per instance).
        layer_dropout: stored on the module but not used by ``forward``.
    """

    def __init__(self, layers, args_route = None, layer_dropout = 0.):
        super().__init__()
        # A ``{}`` default would be a single dict shared by every instance built
        # without an explicit route; mutating one would leak into the others.
        if args_route is None:
            args_route = {}
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        """Run every ``(f, g)`` pair with residual adds; ``kwargs`` are routed per layer."""
        args = route_args(self.args_route, kwargs, len(self.layers))
        for (f, g), (f_args, g_args) in zip(self.layers, args):
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x
|
|
|
|
class DivideMax(nn.Module):
    """Divide a tensor by its maximum along ``dim``.

    The maximum is detached, so no gradient flows through the divisor.
    NOTE(review): a maximum of 0 produces inf/nan, and a negative maximum
    flips signs; neither case is guarded.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        denom = torch.amax(x, dim = self.dim, keepdim = True)
        return x / denom.detach()
|
|
|
|
| |
class LayerScale(nn.Module):
    """Multiply the output of ``fn`` by a learnable per-channel scale.

    The scale parameter has shape ``(1, 1, dim)``. Its initial value depends on
    ``depth``:
    - ``0.1`` for ``depth <= 18``
    - ``1e-5`` for ``18 < depth <= 24``
    - ``1e-6`` otherwise
    """

    def __init__(self, dim, depth, fn):
        super().__init__()
        init_eps = 0.1 if depth <= 18 else (1e-5 if depth <= 24 else 1e-6)
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out * self.scale
|
|
| |
|
|
|
|
class PreNorm(nn.Module):
    """Apply ``fn`` to a LayerNorm-ed input, optionally LayerNorm-ing the output too.

    Args:
        dim: size of the last axis normalized by the LayerNorm(s).
        fn: callable applied to the normalized input; extra kwargs are forwarded.
        sandwich: if True, also apply a LayerNorm to ``fn``'s output.
    """

    def __init__(self, dim, fn, sandwich = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        if sandwich:
            self.norm_out = nn.LayerNorm(dim)
        else:
            self.norm_out = nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``norm_out(fn(norm(x), **kwargs))``."""
        return self.norm_out(self.fn(self.norm(x), **kwargs))
|
|
| |
|
|
|
|
class GEGLU(nn.Module):
    """Gated GELU activation.

    Splits the last axis into two chunks, ``value`` (first) and ``gate``
    (second), and returns ``value * gelu(gate)``.
    NOTE(review): an even-sized last axis is presumably expected; for odd
    sizes ``torch.chunk`` gives the first chunk the extra element.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        gated = F.gelu(gate)
        return value * gated
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network with a GEGLU gate.

    Layout: ``Linear(dim -> 2 * hidden) -> GEGLU -> Dropout -> Linear(hidden -> dim)``,
    where ``hidden = int(dim * mult)``.

    Args:
        dim: size of the input and output last axis.
        dropout: dropout probability applied to the gated hidden activations.
        mult: hidden-width multiplier; may be a float.
    """

    def __init__(self, dim, dropout = 0., mult = 4.):
        super().__init__()
        # nn.Linear requires integer feature counts. With the float default
        # ``mult = 4.`` the raw ``dim * mult`` is a float and construction
        # raised a TypeError, so truncate to int once here.
        hidden = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),  # splits the 2 * hidden features back down to hidden
            nn.Dropout(dropout),
            nn.Linear(hidden, dim)
        )

    def forward(self, x):
        return self.net(x)
|
|
| |
|
|
|
|
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of the input's last axis and of the output.
        seq_len: stored on the module; not read by ``forward``.
        causal: if True, each query is blocked from keys after its own position.
        heads: number of attention heads.
        dim_head: per-head feature size; the projected width is ``dim_head * heads``.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        # 1/sqrt(dim_head), applied to the queries before the q.k product.
        self.scale = dim_head ** -0.5
        self.causal = causal

        projected = dim_head * heads
        # to_qkv is created before to_out so parameter init consumes the RNG in the same order.
        self.to_qkv = nn.Linear(dim, projected * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(projected, dim), nn.Dropout(dropout))

    def forward(self, x, mask = None):
        """Attend over ``x``.

        Args:
            x: rank-3 tensor whose last axis has size ``dim`` — presumably
                (batch, seq, dim); TODO confirm with callers.
            mask: optional 2-D boolean tensor of shape (batch, key_len).
                ``True`` keeps a key position; ``False`` hides it from every
                head and query.

        Returns:
            Tensor with the leading axes of ``x`` and a last axis of size ``dim``.
        """
        # Unpacking the shape also rejects inputs that are not rank 3.
        _batch, _length, _features = x.shape
        heads = self.heads
        device = x.device

        queries, keys, values = (
            rearrange(part, 'b n (h d) -> b h n d', h = heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )
        queries = queries * self.scale

        scores = torch.einsum('b h i d, b h j d -> b h i j', queries, keys)
        # Most negative finite value of the score dtype, so blocked logits get ~0 weight.
        fill_value = -torch.finfo(scores.dtype).max

        if mask is not None:
            key_mask = rearrange(mask, 'b j -> b () () j')
            scores.masked_fill_(~key_mask, fill_value)

        if self.causal:
            q_len, k_len = scores.shape[-2:]
            # Entries above the (k_len - q_len)-offset diagonal are future keys.
            future = torch.ones(q_len, k_len, device = device).triu_(k_len - q_len + 1).bool()
            scores.masked_fill_(future, fill_value)

        weights = scores.softmax(dim = -1)

        mixed = torch.einsum('b h i j, b h j d -> b h i d', weights, values)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))
|
|
|
|
| |
class Transformer(nn.Module):
    """Stack of ``depth`` residual (attention, feed-forward) layer pairs.

    Each branch is wrapped as ``LayerScale(PreNorm(...))``, and the stack is run
    by ``SequentialSequence``. A ``mask`` keyword passed to ``forward`` is routed
    to the attention branch of every layer and never to the feed-forward branch.

    NOTE(review): ``sparse_attn`` is expanded to one entry per layer via
    ``cast_tuple``, but its values are never consulted; every layer uses the
    same dense ``Attention``. Only its length matters: a shorter tuple/list
    caps the number of layers built.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        sparse_attn = False,
        sandwich_norm = False,
    ):
        super().__init__()
        per_layer_sparse = cast_tuple(sparse_attn, depth)
        blocks = nn.ModuleList([])

        for layer_idx, _unused_sparse in zip(range(depth), per_layer_sparse):
            # Attention is built before FeedForward so parameter init keeps its RNG order.
            attention = Attention(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
            feed_forward = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)

            # LayerScale picks its initial scale from this 1-based layer index.
            scale_depth = layer_idx + 1
            blocks.append(nn.ModuleList([
                LayerScale(dim, scale_depth, PreNorm(dim, attention, sandwich = sandwich_norm)),
                LayerScale(dim, scale_depth, PreNorm(dim, feed_forward, sandwich = sandwich_norm)),
            ]))

        # For every layer: send `mask` to f (attention), not to g (feed-forward).
        mask_route = tuple((True, False) for _ in range(depth))
        self.layers = SequentialSequence(blocks, args_route = {'mask': mask_route})

    def forward(self, x, **kwargs):
        """Run the layer stack; keyword arguments such as ``mask`` are routed per layer."""
        return self.layers(x, **kwargs)