Q-TensorFormer / src /tensor_layers.py
Premchan369's picture
v3.0.0: Source files
b9c4adf verified
raw
history blame
10.9 kB
"""
Tensor-Train decomposed linear layers.
v3 improvements:
- SVD-based rank truncation (preserves dominant singular vectors)
- No dead padding cores (factorize_dim ensures all factors ≥ 2)
- torch.no_grad() on set_rank
- Built-in compression statistics
- Budget-aware: auto-selects minimum rank meeting constraints
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional
def factorize_dim(dim: int, max_factors: int = 4) -> Tuple[int, ...]:
"""
Factorize a dimension for TT decomposition.
Ensures all factors >= 2 to avoid dead cores.
"""
if dim <= 1:
return (1,)
factors = []
remaining = dim
for p in [2, 2, 3, 2, 5, 2, 3, 7]:
while remaining % p == 0 and len(factors) < max_factors - 1:
factors.append(p)
remaining //= p
if remaining == 1:
break
if remaining > 1 and len(factors) < max_factors:
factors.append(remaining)
while len(factors) < 2:
val = factors[0] if factors else dim
root = int(math.isqrt(val))
for d in range(root, 1, -1):
if val % d == 0:
factors = [d, val // d]
break
else:
factors = [1, val]
return tuple(factors[:max_factors])
def compute_tt_params(in_features: int, out_features: int,
in_shape: Tuple[int, ...], rank: int) -> int:
"""Compute number of parameters in a TT layer."""
d = len(in_shape)
params = 0
# First core: (1, out_0, in_0, rank)
params += out_features // math.prod(in_shape[1:]) * in_shape[0] * rank if d > 0 else 0
# Middle cores
for k in range(1, d - 1):
params += rank * rank * in_shape[k] * in_shape[k] # approximate
# Last core
if d > 1:
params += rank * in_shape[-1] * in_shape[-1]
return params
class TTLinear(nn.Module):
"""
Tensor-Train decomposed linear layer.
Replaces a dense weight matrix W ∈ R^{out×in} with d TT-cores.
Core k has shape (r_k, out_k, in_k, r_{k+1}) with r_0 = r_d = 1.
Parameters
----------
in_features : int
Input dimension.
out_features : int
Output dimension.
rank : int
TT-rank (bond dimension). Lower → more compression.
bias : bool
Include bias term.
"""
def __init__(self, in_features: int, out_features: int,
rank: int = 8, bias: bool = True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.rank = rank
# Factorize dimensions
in_factors = factorize_dim(in_features)
out_factors = factorize_dim(out_features)
self.ndim = max(len(in_factors), len(out_factors))
# Pad to same length (minimal padding)
in_factors = list(in_factors)
out_factors = list(out_factors)
while len(in_factors) < self.ndim:
in_factors.append(1)
while len(out_factors) < self.ndim:
out_factors.append(1)
self.in_shape = tuple(in_factors)
self.out_shape = tuple(out_factors)
# Initialize TT cores
self.cores = nn.ParameterList()
for k in range(self.ndim):
r_left = 1 if k == 0 else rank
r_right = 1 if k == self.ndim - 1 else rank
core = torch.empty(r_left, out_factors[k], in_factors[k], r_right)
fan = max(1, r_left * in_factors[k] + r_right * out_factors[k])
bound = math.sqrt(6.0 / fan)
nn.init.uniform_(core, -bound, bound)
self.cores.append(core)
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
# Statistics
tt_params = sum(c.numel() for c in self.cores)
if self.bias is not None:
tt_params += self.bias.numel()
dense_params = in_features * out_features
self.compression_ratio = dense_params / max(tt_params, 1)
self._tt_params = tt_params
self._dense_params = dense_params
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass: sequential TT contraction.
Args:
x: (*batch_dims, in_features)
Returns:
(*batch_dims, out_features)
"""
batch_shape = x.shape[:-1]
B = math.prod(batch_shape) if batch_shape else 1
x = x.reshape(B, self.in_features)
state = x.reshape(B, *self.in_shape)
for k in range(self.ndim):
core = self.cores[k]
r_k, o_k, i_k, r_kp1 = core.shape
if k == 0:
rest = math.prod(self.in_shape[1:]) if self.ndim > 1 else 1
s = state.reshape(B, i_k, rest)
cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1)
s = torch.bmm(s.transpose(1, 2), cm.unsqueeze(0).expand(B, -1, -1))
s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1)
state = s.reshape(B, r_kp1, -1)
elif k == self.ndim - 1:
prev_os = math.prod(self.out_shape[:k]) if k > 0 else 1
s = state.reshape(B, r_k, prev_os, i_k)
cm = core.squeeze(-1)
s = torch.einsum('brpi,roi->bpo', s, cm)
state = s.reshape(B, prev_os * o_k)
else:
prev_os = math.prod(self.out_shape[:k]) if k > 0 else 1
rest_in = math.prod(self.in_shape[k + 1:])
s = state.reshape(B, r_k, prev_os * i_k * rest_in)
s = s.reshape(B, r_k, prev_os, i_k, rest_in)
s = torch.einsum('brpix,roiq->bpoqx', s, core)
s = s.permute(0, 3, 1, 2, 4)
state = s.reshape(B, r_kp1, prev_os * o_k * rest_in)
out = state.reshape(B, self.out_features)
if self.bias is not None:
out = out + self.bias
return out.reshape(*batch_shape, self.out_features)
@torch.no_grad()
def set_rank(self, new_rank: int):
"""
SVD-based TT-rank truncation.
Strategy: For each pair of adjacent cores, merge into a supercore,
compute SVD, and keep top `new_rank` singular values.
Then split back into two cores at the new rank.
For single-core edge case (ndim=1): just truncate the SVD of the sole core.
"""
if new_rank == self.rank:
return
new_rank = max(1, new_rank)
if self.ndim == 1:
# Single core: just reshape to matrix and SVD-truncate
old = self.cores[0].data # (1, o_0, i_0, 1)
mat = old.reshape(old.shape[1], old.shape[2]) # (o_0, i_0)
U, S, Vt = torch.linalg.svd(mat, full_matrices=False)
tr = min(new_rank, S.shape[0])
self.cores[0] = nn.Parameter(
((U[:, :tr] * S[:tr]) @ Vt[:tr, :]).reshape(1, old.shape[1], old.shape[2], 1)
)
self.rank = new_rank
else:
# Strategy: compress bond between each adjacent core pair
# We treat each bond independently, truncating to new_rank
for k in range(self.ndim - 1):
core_a = self.cores[k].data # (r_k, o_k, i_k, r_{k+1})
core_b = self.cores[k + 1].data # (r_{k+1}, o_{k+1}, i_{k+1}, r_{k+2})
r_k, o_a, i_a, r_mid = core_a.shape
r_mid2, o_b, i_b, r_k2 = core_b.shape
assert r_mid == r_mid2, f"Rank mismatch: {r_mid} != {r_mid2}"
# Merge cores along the bond to contract the middle rank
# core_a: reshape to (r_k * o_a * i_a, r_mid)
# core_b: reshape to (r_mid, o_b * i_b * r_k2)
# Merged: (r_k * o_a * i_a, o_b * i_b * r_k2)
mat_a = core_a.reshape(-1, r_mid) # (r_k*o_a*i_a, r_mid)
mat_b = core_b.reshape(r_mid, -1) # (r_mid, o_b*i_b*r_k2)
# Reduced SVD at the bond
combined = mat_a @ mat_b # (r_k*o_a*i_a, o_b*i_b*r_k2)
U, S, Vt = torch.linalg.svd(combined, full_matrices=False)
tr = min(new_rank, S.shape[0])
# Split back
U_tr = U[:, :tr] # (r_k*o_a*i_a, tr)
Vt_tr = Vt[:tr, :] # (tr, o_b*i_b*r_k2)
S_sqrt = torch.sqrt(S[:tr] + 1e-10) # (tr,)
new_a = (U_tr * S_sqrt).reshape(r_k, o_a, i_a, tr) # (r_k, o_a, i_a, tr)
new_b = (S_sqrt.unsqueeze(-1) * Vt_tr).reshape(tr, o_b, i_b, r_k2) # (tr, o_b, i_b, r_k2)
self.cores[k].data = new_a
self.cores[k + 1].data = new_b
self.rank = new_rank
# Update stats
tt_params = sum(c.numel() for c in self.cores)
if self.bias is not None:
tt_params += self.bias.numel()
self._tt_params = tt_params
self.compression_ratio = self._dense_params / max(tt_params, 1)
def flops(self, batch_size: int = 1) -> int:
"""Estimate FLOPs for this layer."""
# TT contraction: ~2 * rank^2 * ndim * avg(in_k * out_k)
avg_dim = (sum(self.in_shape) + sum(self.out_shape)) / (2 * self.ndim)
return int(2 * self.rank**2 * self.ndim * avg_dim * batch_size)
def extra_repr(self) -> str:
return (f"in_shape={self.in_shape}, out_shape={self.out_shape}, "
f"rank={self.rank}, compression={self.compression_ratio:.1f}x")
class TTFeedForward(nn.Module):
"""
Tensor-Train Feed-Forward Network.
Replaces standard FFN (Linear↑→GELU→Linear↓) with TT-decomposed layers.
Parameters
----------
hidden_dim : int
Hidden dimension.
ff_multiplier : int
FFN expansion factor (default 4x).
rank : int
TT-rank.
activation : callable
Activation function (default GELU).
"""
def __init__(self, hidden_dim: int, ff_multiplier: int = 4,
rank: int = 8, activation=F.gelu):
super().__init__()
self.hidden_dim = hidden_dim
expanded_dim = hidden_dim * ff_multiplier
self.up_proj = TTLinear(hidden_dim, expanded_dim, rank, bias=True)
self.down_proj = TTLinear(expanded_dim, hidden_dim, rank, bias=True)
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.activation(self.up_proj(x)))
@torch.no_grad()
def set_rank(self, rank: int):
self.up_proj.set_rank(rank)
self.down_proj.set_rank(rank)
@property
def total_params(self) -> int:
return sum(p.numel() for p in self.parameters())
def flops(self, batch_size: int = 1) -> int:
return self.up_proj.flops(batch_size) + self.down_proj.flops(batch_size)