import torch.nn as nn
from einops import rearrange
import opt_einsum as oe

contract = oe.contract
from .hyena_utils import HyenaFilter


class MonarchMixerSequenceMixing(nn.Module):
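    """Monarch Mixer sequence mixing layer.

    Gated long convolution over the sequence dimension using HyenaFilter
    kernels, with optional bidirectional filtering and an optional residual
    long convolution applied to the original input.
    """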
    def __init__(
        self,
        d_model,
        l_max=128,
        dropout=0.0,
        hyena_kernel_lr=None,
        bidirectional=False,
        hyena_lr_pos_emb=1e-5,
        hyena_w=10,
        hyena_w_mod=1,
        hyena_wd=0.1,
        hyena_emb_dim=3,
        hyena_filter_dropout=0.0,
        hyena_filter_order=16,
        residual_long_conv=False,
        hyena_training_additions=False,
    ):
        super().__init__()

        self.d_model = d_model
        self.l_max = l_max
        self.kernel_lr = hyena_kernel_lr
        self.channels = 1
        self.bidirectional = bidirectional
        self.residual_long_conv = residual_long_conv
        self.NUM_PROJECTIONS = 3

        print('-- Bidirectional:', self.bidirectional)
        print("-- Using Long Conv Residual:", self.residual_long_conv)
        print('-- Hyena w:', hyena_w)
        print('-- Hyena w mod:', hyena_w_mod)
        print(f"-- Hyena filter order: {hyena_filter_order}")
        print(f"-- Hyena filter dropout: {hyena_filter_dropout}")
        print(f"-- Hyena filter wd: {hyena_wd}")
        print(f"-- Hyena filter emb dim: {hyena_emb_dim}")
        print(f"-- Hyena filter lr: {hyena_kernel_lr}")
        print(f"-- Hyena filter lr pos emb: {hyena_lr_pos_emb}")

        # Hyena long-convolution filter over the sequence dimension.
        self.filter_fn = HyenaFilter(
            self.d_model,
            order=hyena_filter_order,
            seq_len=self.l_max,
            dropout=hyena_filter_dropout,
            bidirectional=self.bidirectional,
            lr=hyena_kernel_lr,
            lr_pos_emb=hyena_lr_pos_emb,
            w=hyena_w,
            w_mod=hyena_w_mod,
            wd=hyena_wd,
            emb_dim=hyena_emb_dim,
        )

        if self.residual_long_conv:
            # Second filter for the residual long convolution over the input.
            self.filter_fn2 = HyenaFilter(
                self.d_model,
                order=hyena_filter_order,
                seq_len=self.l_max,
                dropout=hyena_filter_dropout,
                bidirectional=self.bidirectional,
                lr=hyena_kernel_lr,
                lr_pos_emb=hyena_lr_pos_emb,
                w=hyena_w,
                w_mod=hyena_w_mod,
                wd=hyena_wd,
                emb_dim=hyena_emb_dim,
            )

        # Input projection to three streams (x1, x2, v) and output projection.
        self.in_linear = nn.Linear(d_model, 3 * d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        self.hyena_training_additions = hyena_training_additions
        if self.hyena_training_additions:
            self.act = nn.Identity()
            self.drop = nn.Dropout(dropout)
            self.layernorm = nn.LayerNorm(d_model)

        # Depthwise short convolution over the concatenated projections.
        total_width = self.d_model * self.NUM_PROJECTIONS
        self.short_filter = nn.Conv1d(
            in_channels=total_width,
            out_channels=total_width,
            kernel_size=3,
            groups=total_width,
            padding=2,
        )
    def forward(self, u, **kwargs):
        # u: (batch, seq_len, d_model)
        if self.hyena_training_additions:
            u = self.layernorm(u)
        L = u.size(-2)

        # Project to 3 * d_model and move to channel-first layout for the conv.
        u_orig = u
        u = self.in_linear(u)
        u = rearrange(u, "b l d -> b d l")

        # Depthwise short convolution; trim the extra padding back to length L.
        uc = self.short_filter(u)[..., :L]

        x1, x2, v = uc.split(self.d_model, dim=1)

        # Pre-gate, then apply the long convolution filter.
        v = v * x1
        if self.hyena_training_additions:
            v = self.drop(v)

        k = self.filter_fn.filter(L, device=u.device)
        k = rearrange(k, "c l d -> c d l")[0]

        if self.bidirectional:
            k_rev = self.filter_fn.filter_rev(L, device=u.device)
            k_rev = rearrange(k_rev, "c l d -> c d l")[0]
        else:
            k_rev = None

        y = self.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias=self.filter_fn.bias[None, :, None])

        if self.residual_long_conv:
            # Residual long convolution applied to the original (unprojected) input.
            k2 = self.filter_fn2.filter(L, device=u.device)
            k2 = rearrange(k2, "c l d -> c d l")[0]

            if self.bidirectional:
                k2_rev = self.filter_fn2.filter_rev(L, device=u.device)
                k2_rev = rearrange(k2_rev, "c l d -> c d l")[0]
            else:
                k2_rev = None

            yu = self.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias=self.filter_fn2.bias[None, :, None])

        # Post-gate, add the residual long conv output, and project back out.
        y = y * x2

        if self.residual_long_conv:
            y = y + yu

        y = y.transpose(-1, -2)
        if self.hyena_training_additions:
            y = self.drop(self.act(y))
        y = self.out_linear(y)

        return y, None
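# The block below is a minimal usage sketch, not part of the original module:
# it assumes `HyenaFilter` resolves via the relative import above and that
# inputs are shaped (batch, seq_len, d_model) with seq_len <= l_max.
if __name__ == "__main__":
    import torch

    mixer = MonarchMixerSequenceMixing(d_model=64, l_max=128)
    u = torch.randn(2, 128, 64)  # (batch, seq_len, d_model)
    y, _ = mixer(u)
    print(y.shape)  # expected: torch.Size([2, 128, 64])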
|
|