| import math |
|
|
| import numpy as np |
| import torch |
|
|
| from .utils import logger |
| from .utils import get_hankel |
|
|
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Return the top-K spectral filters of the Hankel matrix for a sequence length.

    Args:
        seq_len: Sequence length (side of the Hankel matrix).
        K: Number of filters (top eigenvectors) to keep.
        use_hankel_L: Passed through to `get_hankel` to select the Hankel variant.
        device: Target device for the returned filters.
        dtype: Target dtype for the returned filters.

    Returns:
        Tensor of shape (seq_len, K): top-K eigenvectors of the Hankel matrix,
        each scaled by its eigenvalue raised to the 1/4 power.
    """
    Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=dtype)

    # eigh needs full precision (and may not support bf16 at all), so
    # decompose in float32.
    sigma, phi = torch.linalg.eigh(Z.to(torch.float32))

    # eigh returns eigenvalues in ascending order, so the last K columns are
    # the top-K eigenpairs.
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]

    # Scale each eigenvector by sigma^(1/4) while still in float32; the
    # original downcast to `dtype` (bf16 by default) before this step, which
    # needlessly lost precision, and then converted dtype a second time.
    phi_k = phi_k * sigma_k ** 0.25

    # Single conversion to the requested device/dtype at the end.
    return phi_k.to(device=device, dtype=dtype)
|
|
|
|
def compute_dimensions(n: int) -> tuple[int, int, int]:
    """Compute the tensorized dimensions for a sequence length `n`.

    Args:
        n: Sequence length; must be greater than 2.

    Returns:
        Tuple (T_prime, sqrt_T_prime, k_max) where
        sqrt_T_prime = ceil(sqrt(n - 2)), T_prime = sqrt_T_prime**2 + 2,
        and k_max == sqrt_T_prime.

    Raises:
        ValueError: If n <= 2.
    """
    if n <= 2:
        raise ValueError("n must be greater than 2")

    # ceil(sqrt(m)) == isqrt(m - 1) + 1 for m >= 1; math.isqrt is exact
    # integer arithmetic, unlike math.sqrt which can round incorrectly for
    # very large n.
    sqrt_T_prime = math.isqrt(n - 3) + 1
    T_prime = sqrt_T_prime * sqrt_T_prime + 2
    # T_prime - 2 is a perfect square by construction, so re-deriving its
    # ceil-sqrt (as the original did) yields sqrt_T_prime again.
    k_max = sqrt_T_prime
    return T_prime, sqrt_T_prime, k_max
|
|
def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
    """Explicitly accumulate the sum of Kronecker products of scaled eigenvectors.

    Args:
        n: Sequence length used to derive the Hankel dimension.
        k: Requested filter count; clamped to the maximum allowed by `n`.
        device: Computation device.

    Returns:
        1-D tensor of length sqrt_T_prime**2: the sum over all ordered pairs
        (i, j) of kron(phi_i, phi_j), where each phi is a top-k eigenvector
        scaled by its eigenvalue^(1/4).
    """
    T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime).to(device)
    sigma, phi = torch.linalg.eigh(Z)
    sigma_k = sigma[-k:]
    phi_k = phi[:, -k:]

    # Pre-scale each eigenvector once. The original recomputed phi_i inside
    # the inner j-loop, doing k redundant scalings per outer iteration; the
    # values (and the summation order below) are unchanged.
    scaled = [phi_k[:, i] * (sigma_k[i] ** 0.25) for i in range(k)]

    result = torch.zeros(sqrt_T_prime * sqrt_T_prime, device=device)
    for i in range(k):
        for j in range(k):
            result += torch.kron(scaled[i], scaled[j])

    return result
|
|
|
|
def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for given sequence length and filter count.

    Args:
        n: Sequence length
        k: Number of filters
        use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
        device: Computation device
        dtype: Computation dtype
    """
    assert torch.cuda.is_available(), "CUDA is required."

    _, dim, k_max = compute_dimensions(n)
    k = min(k, k_max)

    def _scaled_top_eigvecs(Z: torch.Tensor) -> torch.Tensor:
        # Top-k eigenvectors of Z, each scaled by its eigenvalue^(1/4).
        # eigh orders eigenvalues ascending, so the last k columns are the top.
        eigvals, eigvecs = torch.linalg.eigh(Z)
        return eigvecs[:, -k:] * eigvals[-k:] ** 0.25

    left = _scaled_top_eigvecs(get_hankel(dim))

    if use_hankel_L:
        logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
        right = _scaled_top_eigvecs(get_hankel(dim, True))
    else:
        # Default: tensor the main-Hankel filters with themselves.
        right = left

    return torch.kron(left, right).to(device=device, dtype=dtype)
|
|