"""
Tests for tensor decomposition layers.
Verifies:
- Correct output shapes
- Rank truncation preserves structure
- Compression ratio computation
- Gradient flow
"""
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import pytest

from src.tensor_layers import TTLinear, TTFeedForward, factorize_dim
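# Interface exercised below (inferred from the tests in this file, not a spec
# of src.tensor_layers):
#   factorize_dim(n)            -> tuple of integer factors whose product is n
#   TTLinear(in, out, rank, bias=...) -> maps (..., in) to (..., out); exposes
#                                  .cores, .bias, .rank, .compression_ratio,
#                                  and .set_rank(r)
#   TTFeedForward(dim, ff_multiplier, rank) -> maps (..., dim) to (..., dim);
#                                  exposes .set_rank(r) and .total_params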
class TestFactorizeDim:
    def test_power_of_two(self):
        factors = factorize_dim(64)
        assert all(f >= 2 for f in factors), f"Got dead factor in {factors}"

    def test_prime(self):
        factors = factorize_dim(7)
        # Some factors may be 1 for primes (unavoidable).
        product = 1
        for f in factors:
            product *= f
        assert product == 7, f"Prime 7 product: {factors} = {product}"

    def test_one(self):
        factors = factorize_dim(1)
        assert factors == (1,)

    def test_large(self):
        for dim in [128, 256, 512, 1024]:
            factors = factorize_dim(dim)
            product = 1
            for f in factors:
                product *= f
            assert product == dim, f"Product mismatch: {factors} = {product} != {dim}"
class TestTTLinear:
    def test_output_shape(self):
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y = layer(x)
        assert y.shape == (4, 128)

    def test_batched(self):
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(3, 5, 64)
        y = layer(x)
        assert y.shape == (3, 5, 128)

    def test_gradient_flow(self):
        layer = TTLinear(64, 128, rank=8)
        # The TT cores are trainable, so they should receive gradients even
        # though the input tensor itself does not require grad.
        x = torch.randn(4, 64, requires_grad=False)
        y = layer(x)
        loss = y.sum()
        loss.backward()
        for core in layer.cores:
            assert core.grad is not None
            assert not torch.isnan(core.grad).any()

    def test_set_rank_smaller(self):
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y_before = layer(x)
        layer.set_rank(4)
        y_after = layer(x)
        assert y_after.shape == y_before.shape
        assert layer.rank == 4

    def test_set_rank_larger(self):
        layer = TTLinear(64, 128, rank=4)
        layer.set_rank(8)
        assert layer.rank == 8

    def test_compression_ratio(self):
        # A ratio above 1.0 means the factorized layer holds fewer parameters
        # than the dense 128x256 weight matrix it replaces.
        layer = TTLinear(128, 256, rank=8)
        assert layer.compression_ratio > 1.0

    def test_bias(self):
        layer = TTLinear(64, 128, rank=8, bias=True)
        assert layer.bias is not None
        layer_nb = TTLinear(64, 128, rank=8, bias=False)
        assert layer_nb.bias is None
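# Sketch of an additional cross-check, assuming only the `cores` attribute
# already used above: the raw element count of the TT cores should stay below
# the size of the dense weight matrix that TTLinear replaces.
class TestTTLinearParamBudget:
    def test_cores_smaller_than_dense_weight(self):
        layer = TTLinear(128, 256, rank=8)
        tt_params = sum(core.numel() for core in layer.cores)
        dense_params = 128 * 256  # weight matrix of an equivalent dense layer
        assert tt_params < dense_params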
class TestTTFeedForward:
    def test_output_shape(self):
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y = ffn(x)
        assert y.shape == (4, 128)

    def test_set_rank(self):
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y_before = ffn(x)
        ffn.set_rank(4)
        y_after = ffn(x)
        assert y_after.shape == y_before.shape

    def test_total_params(self):
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        params = ffn.total_params
        assert params > 0
        # Should be fewer than the dense equivalent with hidden size
        # 512 = 128 * ff_multiplier: up (128 * 512) + down (512 * 128) = 131,072.
        dense = 128 * 512 + 512 * 128  # up + down
        assert params < dense
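# Optional convenience so the file can also be run directly; pytest.main is
# the standard programmatic entry point and returns an exit code.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))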