| """ |
| Tensor-Train decomposed linear layers. |
| |
| v3 improvements: |
| - SVD-based rank truncation (preserves dominant singular vectors) |
| - No dead padding cores (factorize_dim ensures all factors ≥ 2) |
| - torch.no_grad() on set_rank |
| - Built-in compression statistics |
| - Budget-aware: auto-selects minimum rank meeting constraints |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Tuple, Optional |
|
|
|
|
| def factorize_dim(dim: int, max_factors: int = 4) -> Tuple[int, ...]: |
| """ |
| Factorize a dimension for TT decomposition. |
| Ensures all factors >= 2 to avoid dead cores. |
| """ |
| if dim <= 1: |
| return (1,) |
| factors = [] |
| remaining = dim |
| for p in [2, 2, 3, 2, 5, 2, 3, 7]: |
| while remaining % p == 0 and len(factors) < max_factors - 1: |
| factors.append(p) |
| remaining //= p |
| if remaining == 1: |
| break |
| if remaining > 1 and len(factors) < max_factors: |
| factors.append(remaining) |
| while len(factors) < 2: |
| val = factors[0] if factors else dim |
| root = int(math.isqrt(val)) |
| for d in range(root, 1, -1): |
| if val % d == 0: |
| factors = [d, val // d] |
| break |
| else: |
| factors = [1, val] |
| return tuple(factors[:max_factors]) |
|
|
|
|
| def compute_tt_params(in_features: int, out_features: int, |
| in_shape: Tuple[int, ...], rank: int) -> int: |
| """Compute number of parameters in a TT layer.""" |
| d = len(in_shape) |
| params = 0 |
| |
| params += out_features // math.prod(in_shape[1:]) * in_shape[0] * rank if d > 0 else 0 |
| |
| for k in range(1, d - 1): |
| params += rank * rank * in_shape[k] * in_shape[k] |
| |
| if d > 1: |
| params += rank * in_shape[-1] * in_shape[-1] |
| return params |
|
|
|
|
| class TTLinear(nn.Module): |
| """ |
| Tensor-Train decomposed linear layer. |
| |
| Replaces a dense weight matrix W ∈ R^{out×in} with d TT-cores. |
| Core k has shape (r_k, out_k, in_k, r_{k+1}) with r_0 = r_d = 1. |
| |
| Parameters |
| ---------- |
| in_features : int |
| Input dimension. |
| out_features : int |
| Output dimension. |
| rank : int |
| TT-rank (bond dimension). Lower → more compression. |
| bias : bool |
| Include bias term. |
| """ |
|
|
| def __init__(self, in_features: int, out_features: int, |
| rank: int = 8, bias: bool = True): |
| super().__init__() |
| self.in_features = in_features |
| self.out_features = out_features |
| self.rank = rank |
|
|
| |
| in_factors = factorize_dim(in_features) |
| out_factors = factorize_dim(out_features) |
| self.ndim = max(len(in_factors), len(out_factors)) |
|
|
| |
| in_factors = list(in_factors) |
| out_factors = list(out_factors) |
| while len(in_factors) < self.ndim: |
| in_factors.append(1) |
| while len(out_factors) < self.ndim: |
| out_factors.append(1) |
| self.in_shape = tuple(in_factors) |
| self.out_shape = tuple(out_factors) |
|
|
| |
| self.cores = nn.ParameterList() |
| for k in range(self.ndim): |
| r_left = 1 if k == 0 else rank |
| r_right = 1 if k == self.ndim - 1 else rank |
| core = torch.empty(r_left, out_factors[k], in_factors[k], r_right) |
| fan = max(1, r_left * in_factors[k] + r_right * out_factors[k]) |
| bound = math.sqrt(6.0 / fan) |
| nn.init.uniform_(core, -bound, bound) |
| self.cores.append(core) |
|
|
| self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None |
|
|
| |
| tt_params = sum(c.numel() for c in self.cores) |
| if self.bias is not None: |
| tt_params += self.bias.numel() |
| dense_params = in_features * out_features |
| self.compression_ratio = dense_params / max(tt_params, 1) |
| self._tt_params = tt_params |
| self._dense_params = dense_params |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| Forward pass: sequential TT contraction. |
| |
| Args: |
| x: (*batch_dims, in_features) |
| Returns: |
| (*batch_dims, out_features) |
| """ |
| batch_shape = x.shape[:-1] |
| B = math.prod(batch_shape) if batch_shape else 1 |
| x = x.reshape(B, self.in_features) |
| state = x.reshape(B, *self.in_shape) |
|
|
| for k in range(self.ndim): |
| core = self.cores[k] |
| r_k, o_k, i_k, r_kp1 = core.shape |
|
|
| if k == 0: |
| rest = math.prod(self.in_shape[1:]) if self.ndim > 1 else 1 |
| s = state.reshape(B, i_k, rest) |
| cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1) |
| s = torch.bmm(s.transpose(1, 2), cm.unsqueeze(0).expand(B, -1, -1)) |
| s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1) |
| state = s.reshape(B, r_kp1, -1) |
|
|
| elif k == self.ndim - 1: |
| prev_os = math.prod(self.out_shape[:k]) if k > 0 else 1 |
| s = state.reshape(B, r_k, prev_os, i_k) |
| cm = core.squeeze(-1) |
| s = torch.einsum('brpi,roi->bpo', s, cm) |
| state = s.reshape(B, prev_os * o_k) |
|
|
| else: |
| prev_os = math.prod(self.out_shape[:k]) if k > 0 else 1 |
| rest_in = math.prod(self.in_shape[k + 1:]) |
| s = state.reshape(B, r_k, prev_os * i_k * rest_in) |
| s = s.reshape(B, r_k, prev_os, i_k, rest_in) |
| s = torch.einsum('brpix,roiq->bpoqx', s, core) |
| s = s.permute(0, 3, 1, 2, 4) |
| state = s.reshape(B, r_kp1, prev_os * o_k * rest_in) |
|
|
| out = state.reshape(B, self.out_features) |
| if self.bias is not None: |
| out = out + self.bias |
| return out.reshape(*batch_shape, self.out_features) |
|
|
| @torch.no_grad() |
| def set_rank(self, new_rank: int): |
| """ |
| SVD-based TT-rank truncation. |
| |
| Strategy: For each pair of adjacent cores, merge into a supercore, |
| compute SVD, and keep top `new_rank` singular values. |
| Then split back into two cores at the new rank. |
| |
| For single-core edge case (ndim=1): just truncate the SVD of the sole core. |
| """ |
| if new_rank == self.rank: |
| return |
| new_rank = max(1, new_rank) |
|
|
| if self.ndim == 1: |
| |
| old = self.cores[0].data |
| mat = old.reshape(old.shape[1], old.shape[2]) |
| U, S, Vt = torch.linalg.svd(mat, full_matrices=False) |
| tr = min(new_rank, S.shape[0]) |
| self.cores[0] = nn.Parameter( |
| ((U[:, :tr] * S[:tr]) @ Vt[:tr, :]).reshape(1, old.shape[1], old.shape[2], 1) |
| ) |
| self.rank = new_rank |
| else: |
| |
| |
| for k in range(self.ndim - 1): |
| core_a = self.cores[k].data |
| core_b = self.cores[k + 1].data |
|
|
| r_k, o_a, i_a, r_mid = core_a.shape |
| r_mid2, o_b, i_b, r_k2 = core_b.shape |
| assert r_mid == r_mid2, f"Rank mismatch: {r_mid} != {r_mid2}" |
|
|
| |
| |
| |
| |
| mat_a = core_a.reshape(-1, r_mid) |
| mat_b = core_b.reshape(r_mid, -1) |
|
|
| |
| combined = mat_a @ mat_b |
| U, S, Vt = torch.linalg.svd(combined, full_matrices=False) |
| tr = min(new_rank, S.shape[0]) |
|
|
| |
| U_tr = U[:, :tr] |
| Vt_tr = Vt[:tr, :] |
| S_sqrt = torch.sqrt(S[:tr] + 1e-10) |
|
|
| new_a = (U_tr * S_sqrt).reshape(r_k, o_a, i_a, tr) |
| new_b = (S_sqrt.unsqueeze(-1) * Vt_tr).reshape(tr, o_b, i_b, r_k2) |
|
|
| self.cores[k].data = new_a |
| self.cores[k + 1].data = new_b |
|
|
| self.rank = new_rank |
|
|
| |
| tt_params = sum(c.numel() for c in self.cores) |
| if self.bias is not None: |
| tt_params += self.bias.numel() |
| self._tt_params = tt_params |
| self.compression_ratio = self._dense_params / max(tt_params, 1) |
|
|
| def flops(self, batch_size: int = 1) -> int: |
| """Estimate FLOPs for this layer.""" |
| |
| avg_dim = (sum(self.in_shape) + sum(self.out_shape)) / (2 * self.ndim) |
| return int(2 * self.rank**2 * self.ndim * avg_dim * batch_size) |
|
|
| def extra_repr(self) -> str: |
| return (f"in_shape={self.in_shape}, out_shape={self.out_shape}, " |
| f"rank={self.rank}, compression={self.compression_ratio:.1f}x") |
|
|
|
|
| class TTFeedForward(nn.Module): |
| """ |
| Tensor-Train Feed-Forward Network. |
| |
| Replaces standard FFN (Linear↑→GELU→Linear↓) with TT-decomposed layers. |
| |
| Parameters |
| ---------- |
| hidden_dim : int |
| Hidden dimension. |
| ff_multiplier : int |
| FFN expansion factor (default 4x). |
| rank : int |
| TT-rank. |
| activation : callable |
| Activation function (default GELU). |
| """ |
|
|
| def __init__(self, hidden_dim: int, ff_multiplier: int = 4, |
| rank: int = 8, activation=F.gelu): |
| super().__init__() |
| self.hidden_dim = hidden_dim |
| expanded_dim = hidden_dim * ff_multiplier |
|
|
| self.up_proj = TTLinear(hidden_dim, expanded_dim, rank, bias=True) |
| self.down_proj = TTLinear(expanded_dim, hidden_dim, rank, bias=True) |
| self.activation = activation |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.down_proj(self.activation(self.up_proj(x))) |
|
|
| @torch.no_grad() |
| def set_rank(self, rank: int): |
| self.up_proj.set_rank(rank) |
| self.down_proj.set_rank(rank) |
|
|
| @property |
| def total_params(self) -> int: |
| return sum(p.numel() for p in self.parameters()) |
|
|
| def flops(self, batch_size: int = 1) -> int: |
| return self.up_proj.flops(batch_size) + self.down_proj.flops(batch_size) |
|
|