"""
Tests for tensor decomposition layers.

Verifies:
- Correct output shapes
- Rank truncation preserves structure
- Compression ratio computation
- Gradient flow
"""

import functools
import operator
import os
import sys

# Make the repository root importable so `src.tensor_layers` resolves when the
# tests are run from inside the tests directory. Must run before the project
# import below.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import pytest

from src.tensor_layers import TTLinear, TTFeedForward, factorize_dim


def _product(factors):
    """Return the product of every entry in *factors* (1 if it is empty)."""
    return functools.reduce(operator.mul, factors, 1)


class TestFactorizeDim:
    """Checks on the factor tuples returned by ``factorize_dim``."""

    def test_power_of_two(self):
        """64 must factor without any factor below 2, and multiply back to 64."""
        factors = factorize_dim(64)
        assert all(f >= 2 for f in factors), f"Got dead factor in {factors}"
        # Without this, a wrong factorization such as (2, 2) would still pass
        # the "no dead factor" check above.
        product = _product(factors)
        assert product == 64, f"Product mismatch: {factors} = {product} != 64"

    def test_prime(self):
        """The factors of a prime must still multiply back to that prime."""
        factors = factorize_dim(7)
        # Some factors may be 1 for primes (unavoidable)
        product = _product(factors)
        assert product == 7, f"Prime 7 product: {factors} = {product}"

    def test_one(self):
        """A dimension of 1 yields the single-factor tuple ``(1,)``."""
        factors = factorize_dim(1)
        assert factors == (1,)

    def test_large(self):
        """For larger powers of two the factors multiply back to the input."""
        for dim in [128, 256, 512, 1024]:
            factors = factorize_dim(dim)
            product = _product(factors)
            assert product == dim, f"Product mismatch: {factors} = {product} != {dim}"


class TestTTLinear:
    """Shape, gradient, rank and bias checks for ``TTLinear``."""

    def test_output_shape(self):
        """A (4, 64) input through a 64 -> 128 layer yields (4, 128)."""
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y = layer(x)
        assert y.shape == (4, 128)

    def test_batched(self):
        """Extra leading dimensions are preserved: (3, 5, 64) -> (3, 5, 128)."""
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(3, 5, 64)
        y = layer(x)
        assert y.shape == (3, 5, 128)

    def test_gradient_flow(self):
        """Backward through the layer gives every core a NaN-free gradient."""
        layer = TTLinear(64, 128, rank=8)
        # Only the layer's own parameters are inspected, so the input itself
        # does not need to track gradients.
        x = torch.randn(4, 64, requires_grad=False)
        y = layer(x)
        loss = y.sum()
        loss.backward()

        # Materialize first: looping over an empty `cores` would otherwise let
        # the test pass without checking anything.
        cores = list(layer.cores)
        assert cores, "TTLinear exposes no cores to check gradients on"
        for core in cores:
            assert core.grad is not None
            assert not torch.isnan(core.grad).any()

    def test_set_rank_smaller(self):
        """Lowering the rank keeps the output shape and updates ``layer.rank``."""
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y_before = layer(x)
        layer.set_rank(4)
        y_after = layer(x)
        assert y_after.shape == y_before.shape
        assert layer.rank == 4

    def test_set_rank_larger(self):
        """Raising the rank is reflected in ``layer.rank``."""
        layer = TTLinear(64, 128, rank=4)
        layer.set_rank(8)
        assert layer.rank == 8

    def test_compression_ratio(self):
        """A rank-8 128 -> 256 layer reports a compression ratio above 1.0."""
        layer = TTLinear(128, 256, rank=8)
        assert layer.compression_ratio > 1.0

    def test_bias(self):
        """``bias=True`` creates a bias; ``bias=False`` leaves it as ``None``."""
        layer = TTLinear(64, 128, rank=8, bias=True)
        assert layer.bias is not None
        layer_nb = TTLinear(64, 128, rank=8, bias=False)
        assert layer_nb.bias is None


class TestTTFeedForward:
    """Shape, rank and parameter-count checks for ``TTFeedForward``."""

    def test_output_shape(self):
        """The feed-forward block maps (4, 128) back to (4, 128)."""
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y = ffn(x)
        assert y.shape == (4, 128)

    def test_set_rank(self):
        """Lowering the rank does not change the output shape."""
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y_before = ffn(x)
        ffn.set_rank(4)
        y_after = ffn(x)
        assert y_after.shape == y_before.shape

    def test_total_params(self):
        """``total_params`` is positive and below the dense weight count."""
        d_model = 128
        ff_multiplier = 4
        ffn = TTFeedForward(d_model, ff_multiplier=ff_multiplier, rank=8)
        params = ffn.total_params
        assert params > 0
        # Should be fewer than dense equivalent. The dense count is derived
        # from the constructor arguments so the two cannot drift apart.
        hidden = d_model * ff_multiplier
        dense = d_model * hidden + hidden * d_model  # up + down
        assert params < dense