| """ |
| Tests for tensor decomposition layers. |
| |
| Verifies: |
| - Correct output shapes |
| - Rank truncation preserves structure |
| - Compression ratio computation |
| - Gradient flow |
| """ |
|
|
import math
import os
import sys

# Make the project root importable when running the tests directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import pytest
import torch

from src.tensor_layers import TTLinear, TTFeedForward, factorize_dim
|
|
|
|
class TestFactorizeDim:
    """Unit tests for factorize_dim: factors must multiply back to the input."""

    def test_power_of_two(self):
        factors = factorize_dim(64)
        # Every factor must be a real factor (>= 2) -- a 1 would be dead weight.
        assert all(f >= 2 for f in factors), f"Got dead factor in {factors}"
        # The factorization must reconstruct the original dimension.
        assert math.prod(factors) == 64, f"Product mismatch: {factors}"

    def test_prime(self):
        # A prime cannot be split; the only valid factorization multiplies to itself.
        factors = factorize_dim(7)
        product = math.prod(factors)
        assert product == 7, f"Prime 7 product: {factors} = {product}"

    def test_one(self):
        # Degenerate case: dimension 1 factorizes to the singleton (1,).
        factors = factorize_dim(1)
        assert factors == (1,)

    def test_large(self):
        # Round-trip property for common embedding sizes.
        for dim in [128, 256, 512, 1024]:
            factors = factorize_dim(dim)
            product = math.prod(factors)
            assert product == dim, f"Product mismatch: {factors} = {product} != {dim}"
|
|
|
|
class TestTTLinear:
    """Unit tests for the tensor-train linear layer."""

    def test_output_shape(self):
        # A 2-D batch maps (batch, in_features) -> (batch, out_features).
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y = layer(x)
        assert y.shape == (4, 128)

    def test_batched(self):
        # Leading dimensions beyond the feature axis must pass through untouched.
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(3, 5, 64)
        y = layer(x)
        assert y.shape == (3, 5, 128)

    def test_gradient_flow(self):
        # Backprop must reach every TT core with finite (non-NaN) gradients;
        # the input itself needs no gradient for this to hold.
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y = layer(x)
        loss = y.sum()
        loss.backward()
        for core in layer.cores:
            assert core.grad is not None
            assert not torch.isnan(core.grad).any()

    def test_set_rank_smaller(self):
        # Truncating the rank must preserve the layer's input/output contract.
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y_before = layer(x)

        layer.set_rank(4)
        y_after = layer(x)

        assert y_after.shape == y_before.shape
        assert layer.rank == 4

    def test_set_rank_larger(self):
        # Growing the rank should also be supported.
        layer = TTLinear(64, 128, rank=4)
        layer.set_rank(8)
        assert layer.rank == 8

    def test_compression_ratio(self):
        # The TT factorization must use fewer parameters than a dense layer.
        layer = TTLinear(128, 256, rank=8)
        assert layer.compression_ratio > 1.0

    def test_bias(self):
        # Bias is present when requested and absent otherwise.
        layer = TTLinear(64, 128, rank=8, bias=True)
        assert layer.bias is not None

        layer_nb = TTLinear(64, 128, rank=8, bias=False)
        assert layer_nb.bias is None
|
|
|
|
class TestTTFeedForward:
    """Unit tests for the tensor-train feed-forward block."""

    def test_output_shape(self):
        # The FFN is shape-preserving: (batch, d_model) -> (batch, d_model).
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y = ffn(x)
        assert y.shape == (4, 128)

    def test_set_rank(self):
        # Rank truncation must not change the output shape.
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y_before = ffn(x)

        ffn.set_rank(4)
        y_after = ffn(x)

        assert y_after.shape == y_before.shape

    def test_total_params(self):
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        params = ffn.total_params
        assert params > 0

        # Dense equivalent: 128 -> 512 -> 128 (ff_multiplier=4), ignoring biases.
        dense = 128 * 512 + 512 * 128
        assert params < dense
|
|