import warnings

import torch
import torch.nn as nn

from .modules import STU
from .modules import MLP
from .modules import Attention
# Prefer the Liger (Triton) SwiGLU MLP when liger_kernel is installed;
# `triton_mlp` records which implementation the layers below should build.
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    # Optional dependency: emit a warning (instead of printing to stdout) so
    # importers can filter or silence it, then fall back to the local MLP.
    warnings.warn(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False
|
|
| try: |
| from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm |
| triton_norm = True |
| except ImportError as e: |
| print( |
| f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation." |
| ) |
| from torch.nn import RMSNorm |
| triton_norm = False |
|
|
|
|
class STULayer(nn.Module):
    """Pre-norm residual block: ``x + STU(norm(x))`` then ``x + MLP(norm(x))``.

    Each sub-layer's input is cast to the dtype of that sub-layer's parameters,
    and its output is cast back to the residual stream's dtype. The residual
    stream can therefore be kept in a different precision from the sub-layers.

    Args:
        config: Model config. This layer reads ``config.torch_dtype`` (a
            ``torch.dtype``, ``None``, or a dtype name such as ``"bfloat16"`` /
            ``"torch.bfloat16"``) and ``config.n_embd``; the whole config is
            forwarded to ``STU`` and to the MLP.
        phi: Forwarded unchanged to ``STU``.
        n: Forwarded unchanged to ``STU``.
    """

    def __init__(self, config, phi, n):
        super().__init__()
        torch_dtype = self._resolve_dtype(config.torch_dtype)

        # Submodules are registered in the same order as before
        # (stu_norm, stu, mlp_norm, mlp) so parameters() order is unchanged.
        self.stu_norm = self._make_norm(config.n_embd, torch_dtype)
        self.stu = STU(config, phi, n).to(dtype=torch_dtype)
        self.mlp_norm = self._make_norm(config.n_embd, torch_dtype)
        # TritonMLP is constructed without a dtype argument, so both variants
        # are cast explicitly after construction.
        mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        self.mlp = mlp.to(dtype=torch_dtype)

    @staticmethod
    def _resolve_dtype(dtype):
        """Convert a dtype name (``"bfloat16"`` or ``"torch.bfloat16"``) to a ``torch.dtype``.

        Non-string values (a ``torch.dtype`` or ``None``) are returned unchanged.

        Raises:
            ValueError: if ``dtype`` is a string that does not name a ``torch.dtype``.
        """
        if not isinstance(dtype, str):
            return dtype
        name = dtype[len("torch."):] if dtype.startswith("torch.") else dtype
        resolved = getattr(torch, name, None)
        # getattr alone would accept any torch attribute (e.g. "nn"); insist on a dtype.
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"config.torch_dtype={dtype!r} is not a valid torch dtype name")
        return resolved

    @staticmethod
    def _make_norm(dim, dtype):
        """Build an RMSNorm over ``dim`` features (Liger kernel if available), cast to ``dtype``."""
        norm = TritonNorm(dim) if triton_norm else RMSNorm(dim, dtype=dtype)
        return norm.to(dtype=dtype)

    @staticmethod
    def _compute_dtype(module, default):
        """Return the dtype of ``module``'s first floating-point parameter, or ``default`` if none."""
        for param in module.parameters():
            if param.is_floating_point():
                return param.dtype
        return default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the STU and MLP residual sub-layers; the result has ``x``'s dtype."""
        stu_in = self.stu_norm(x).to(dtype=self._compute_dtype(self.stu, x.dtype))
        x = x + self.stu(stu_in).to(dtype=x.dtype)

        mlp_in = self.mlp_norm(x).to(dtype=self._compute_dtype(self.mlp, x.dtype))
        x = x + self.mlp(mlp_in).to(dtype=x.dtype)
        return x
|
|
class AttentionLayer(nn.Module):
    """Pre-norm residual block: ``x + Attention(norm(x))`` then ``x + MLP(norm(x))``.

    As in the STU layer, each sub-layer's input is cast to the dtype of that
    sub-layer's parameters, and its output is cast back to the residual
    stream's dtype.

    Args:
        config: Model config. This layer reads ``config.torch_dtype`` (a
            ``torch.dtype``, ``None``, or a dtype name such as ``"bfloat16"`` /
            ``"torch.bfloat16"``) and ``config.n_embd``; the whole config is
            forwarded to ``Attention`` and to the MLP.
    """

    def __init__(self, config) -> None:
        super().__init__()
        torch_dtype = self._resolve_dtype(config.torch_dtype)

        # Submodules are registered in the same order as before
        # (attn_norm, attn, mlp_norm, mlp) so parameters() order is unchanged.
        self.attn_norm = self._make_norm(config.n_embd, torch_dtype)
        self.attn = Attention(config).to(dtype=torch_dtype)
        self.mlp_norm = self._make_norm(config.n_embd, torch_dtype)
        # TritonMLP is constructed without a dtype argument, so both variants
        # are cast explicitly (once) after construction.
        mlp = TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        self.mlp = mlp.to(dtype=torch_dtype)

    @staticmethod
    def _resolve_dtype(dtype):
        """Convert a dtype name (``"bfloat16"`` or ``"torch.bfloat16"``) to a ``torch.dtype``.

        Non-string values (a ``torch.dtype`` or ``None``) are returned unchanged.

        Raises:
            ValueError: if ``dtype`` is a string that does not name a ``torch.dtype``.
        """
        if not isinstance(dtype, str):
            return dtype
        name = dtype[len("torch."):] if dtype.startswith("torch.") else dtype
        resolved = getattr(torch, name, None)
        # getattr alone would accept any torch attribute (e.g. "nn"); insist on a dtype.
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"config.torch_dtype={dtype!r} is not a valid torch dtype name")
        return resolved

    @staticmethod
    def _make_norm(dim, dtype):
        """Build an RMSNorm over ``dim`` features (Liger kernel if available), cast to ``dtype``."""
        norm = TritonNorm(dim) if triton_norm else RMSNorm(dim, dtype=dtype)
        return norm.to(dtype=dtype)

    @staticmethod
    def _compute_dtype(module, default):
        """Return the dtype of ``module``'s first floating-point parameter, or ``default`` if none."""
        for param in module.parameters():
            if param.is_floating_point():
                return param.dtype
        return default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP residual sub-layers; the result has ``x``'s dtype."""
        attn_in = self.attn_norm(x).to(dtype=self._compute_dtype(self.attn, x.dtype))
        x = x + self.attn(attn_in).to(dtype=x.dtype)

        mlp_in = self.mlp_norm(x).to(dtype=self._compute_dtype(self.mlp, x.dtype))
        x = x + self.mlp(mlp_in).to(dtype=x.dtype)
        return x
|
|