| import torch |
| import torch.nn as nn |
|
|
| from .convolve import convolve, flash_convolve |
|
|
# FlashFFTConv is an optional accelerator: record whether it can be imported
# so STU can fall back to the pure-PyTorch convolution when it cannot.
try:
    from flashfftconv import FlashFFTConv
except ImportError as e:
    print(
        f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation."
    )
    flash_fft_available = False
else:
    flash_fft_available = True
|
|
|
|
class STU(nn.Module):
    """Filtering layer built on a fixed filter tensor ``phi`` plus learned weights.

    The input is convolved with ``phi``. The convolution uses FlashFFTConv when
    ``config.use_flash_fft`` is set and the import succeeded. Otherwise it
    uses the PyTorch ``convolve`` fallback. The result is combined with
    learned tensors. There are two parameterizations:

    * ``config.use_approx=True``: the input is projected by ``M_inputs`` and
      ``phi`` by ``M_filters`` before the convolution.
    * ``config.use_approx=False``: the convolution outputs are contracted with
      ``M_phi_plus`` (and with ``M_phi_minus`` unless ``config.use_hankel_L``).

    Args:
        config: object providing ``torch_dtype`` (a ``torch.dtype`` or its name
            as a string, e.g. ``"bfloat16"``), ``device``, ``num_eigh``,
            ``n_embd``, ``use_hankel_L``, ``use_approx`` and ``use_flash_fft``.
        phi: filter tensor, cast to ``config.device`` / ``torch_dtype``.
            Its last dimension must equal ``num_eigh`` for ``phi @ M_filters``;
            presumably shaped (seq_len, num_eigh) — TODO confirm.
        n: length passed to ``FlashFFTConv`` and ``convolve`` (presumably the
            sequence/FFT length — confirm against ``convolve``).
    """

    def __init__(self, config, phi, n) -> None:
        super().__init__()
        self.config = config
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        # NOTE: phi is stored as a plain attribute, not a parameter or buffer,
        # so Module.to()/.cuda()/.half() will not move or cast it. forward()
        # re-aligns it with the weights for that reason.
        self.phi = phi.to(device=config.device, dtype=torch_dtype)
        self.n = n
        self.K = config.num_eigh
        self.d_in = config.n_embd
        self.d_out = config.n_embd
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx
        # The FlashFFTConv kernel is always built in bfloat16, independent of
        # config.torch_dtype.
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )
        # NOTE(review): weights are allocated with torch.empty (uninitialized
        # memory); presumably the enclosing model initializes them — confirm.
        if self.use_approx:
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=torch_dtype)
            )
        else:
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
                )

    def _active_weight(self) -> torch.Tensor:
        """Return a learned tensor that exists for the configured branch.

        ``M_inputs`` exists only when ``use_approx`` is True, and
        ``M_phi_plus`` only when it is False. ``forward`` reads this tensor's
        dtype and device.
        """
        return self.M_inputs if self.use_approx else self.M_phi_plus

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the filters and apply the learned weights.

        Args:
            x: input tensor. It is cast to the weights' dtype first. In approx
                mode its last dimension must be ``n_embd`` (for
                ``x @ M_inputs``). Other shape requirements come from
                ``convolve`` / ``flash_convolve``.

        Returns:
            The "plus" component alone when ``use_hankel_L`` is set, otherwise
            the sum of the "plus" and "minus" components.
        """
        # Use a weight of the active branch: reading M_inputs unconditionally
        # raised AttributeError whenever use_approx was False.
        weight = self._active_weight()
        dtype = weight.dtype
        x = x.to(dtype=dtype)
        # Align phi with the weights. This is a no-op when device and dtype
        # already match, and it fixes the mismatch after the module has been
        # moved or cast, because phi is not a registered buffer.
        phi = self.phi.to(device=weight.device, dtype=dtype)

        if self.use_approx:
            x_proj = (x @ self.M_inputs).to(dtype=dtype)
            phi_proj = (phi @ self.M_filters).to(dtype=dtype)
            if self.flash_fft is not None:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            if self.flash_fft is not None:
                U_plus, U_minus = flash_convolve(
                    x, phi, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, phi, self.n, self.use_approx)
            # Contract dims 2 and 3 of U against the (K, d_in) axes of the
            # weight. This leaves U's leading dims followed by d_out.
            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
|
|