from typing import Optional

import torch
import torch.nn as nn

from .attn import FlexAttention
from .modules import MLP
from .modules import Attention
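
# Prefer the fused Triton SwiGLU MLP kernel from liger_kernel when it is
# installed; otherwise fall back to the vanilla SwiGLU MLP defined in .modules.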
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

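# Prefer the fused Triton RMSNorm kernel when available; otherwise fall back to
# the native torch.nn.RMSNorm implementation.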
try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm
    triton_norm = False


class AttentionLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped in
    RMSNorm and a residual connection."""

    def __init__(self, config, mask_mod, score_mod=None) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = FlexAttention(
            config=config,
            mask_mod=mask_mod,
            score_mod=score_mod,
        )
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Residual connections around the pre-normalized attention and MLP sub-layers.
        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x
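

# Illustrative usage sketch (kept as a comment): the real `config` schema and the
# FlexAttention wrapper's mask_mod interface are defined elsewhere in this repo,
# so the names below (causal_mask, batch_size, seq_len, freqs_cis) are
# placeholders and assumptions, not the actual API.
#
#   def causal_mask(b, h, q_idx, kv_idx):
#       # FlexAttention-style predicate: a query may attend to keys at or before it.
#       return q_idx >= kv_idx
#
#   layer = AttentionLayer(config, mask_mod=causal_mask)
#   x = torch.randn(batch_size, seq_len, config.dim)
#   out = layer(x, freqs_cis=freqs_cis)  # output shape matches the input: (batch, seq, dim)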