File size: 3,312 Bytes
bcadbf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Tests for tensor decomposition layers.

Verifies:
  - Correct output shapes
  - Output shapes are preserved when the rank is truncated or grown
  - Compression ratio and parameter savings versus a dense equivalent
  - Gradient flow
"""

import math
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import pytest
import torch
from src.tensor_layers import TTLinear, TTFeedForward, factorize_dim


class TestFactorizeDim:
    """Tests for ``factorize_dim``: the returned factors must multiply back to the input."""

    def test_power_of_two(self):
        # A power of two can be split without trivial factors, so every factor
        # should be >= 2 -- and the split must still be a valid factorization.
        factors = factorize_dim(64)
        assert all(f >= 2 for f in factors), f"Got dead factor in {factors}"
        product = math.prod(factors)
        assert product == 64, f"Product mismatch: {factors} = {product} != 64"

    def test_prime(self):
        # A prime has no non-trivial split, so factors of 1 are tolerated here;
        # only the product is checked.
        factors = factorize_dim(7)
        product = math.prod(factors)
        assert product == 7, f"Prime 7 product: {factors} = {product}"

    def test_one(self):
        # The degenerate dimension 1 is expected to come back as the 1-tuple (1,).
        factors = factorize_dim(1)
        assert factors == (1,)

    def test_large(self):
        # Larger powers of two must still round-trip through the product.
        for dim in (128, 256, 512, 1024):
            factors = factorize_dim(dim)
            product = math.prod(factors)
            assert product == dim, f"Product mismatch: {factors} = {product} != {dim}"


class TestTTLinear:
    """Tests for ``TTLinear``: shapes, gradient flow, rank changes, and bias handling."""

    def test_output_shape(self):
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y = layer(x)
        assert y.shape == (4, 128)

    def test_batched(self):
        # Extra leading dimensions are expected to pass through unchanged; only
        # the trailing feature dimension is mapped from 64 to 128.
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(3, 5, 64)
        y = layer(x)
        assert y.shape == (3, 5, 128)

    def test_gradient_flow(self):
        layer = TTLinear(64, 128, rank=8)
        # Track the input as well, so the test covers backprop through the whole
        # layer and not only into its parameters.
        x = torch.randn(4, 64, requires_grad=True)
        loss = layer(x).sum()
        loss.backward()

        cores = list(layer.cores)
        # Guard against a vacuous pass if the layer exposes no cores.
        assert cores, "TTLinear exposes no cores to check"
        for core in cores:
            assert core.grad is not None
            # isfinite rejects both NaN and +/-inf gradients.
            assert torch.isfinite(core.grad).all()

        assert x.grad is not None
        assert x.grad.shape == x.shape
        assert torch.isfinite(x.grad).all()

    def test_set_rank_smaller(self):
        layer = TTLinear(64, 128, rank=8)
        x = torch.randn(4, 64)
        y_before = layer(x)

        layer.set_rank(4)
        y_after = layer(x)

        assert y_after.shape == y_before.shape
        assert torch.isfinite(y_after).all()
        assert layer.rank == 4

    def test_set_rank_larger(self):
        layer = TTLinear(64, 128, rank=4)
        layer.set_rank(8)
        assert layer.rank == 8
        # Growing the rank must leave the layer usable, not just update the attribute.
        y = layer(torch.randn(4, 64))
        assert y.shape == (4, 128)
        assert torch.isfinite(y).all()

    def test_compression_ratio(self):
        # Assumes a ratio above 1.0 means the TT layer is smaller than the dense
        # layer it replaces -- confirm against TTLinear.compression_ratio.
        layer = TTLinear(128, 256, rank=8)
        assert layer.compression_ratio > 1.0

    def test_bias(self):
        layer = TTLinear(64, 128, rank=8, bias=True)
        assert layer.bias is not None

        layer_nb = TTLinear(64, 128, rank=8, bias=False)
        assert layer_nb.bias is None


class TestTTFeedForward:
    """Tests for ``TTFeedForward``: shapes, rank changes, and parameter savings."""

    def test_output_shape(self):
        # The block is expected to map d_model back to d_model.
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y = ffn(x)
        assert y.shape == (4, 128)

    def test_set_rank(self):
        ffn = TTFeedForward(128, ff_multiplier=4, rank=8)
        x = torch.randn(4, 128)
        y_before = ffn(x)

        ffn.set_rank(4)
        y_after = ffn(x)

        assert y_after.shape == y_before.shape
        assert torch.isfinite(y_after).all()

    def test_total_params(self):
        d_model, ff_multiplier = 128, 4
        ffn = TTFeedForward(d_model, ff_multiplier=ff_multiplier, rank=8)
        params = ffn.total_params
        assert params > 0
        # Dense reference: weight matrices of an up-projection (d_model -> hidden)
        # plus a down-projection (hidden -> d_model), biases excluded. Deriving
        # hidden from the constructor arguments keeps this in sync if they change;
        # hidden = d_model * ff_multiplier is assumed -- confirm against TTFeedForward.
        hidden = d_model * ff_multiplier
        dense = d_model * hidden + hidden * d_model  # up + down
        assert params < dense, f"TT params {params} not below dense {dense}"