File size: 1,624 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Tests for LTI injection in ACT loops."""
import torch
from arbitor.components import LTIInjection, ByteHead
from arbitor.decoders import VideoHead


def test_lti_basic_properties():
    """Check that an LTIInjection call preserves shape and stays finite.

    Three random tensors of shape (2, 10, 64) are passed to
    ``LTIInjection(64)``; the result must have the same shape as the first
    argument and contain only finite values.
    """
    injection = LTIInjection(64)
    # The generator is consumed left to right, so the tensors are drawn in
    # the same order as the positional arguments they feed.
    h, e, t = (torch.randn(2, 10, 64) for _ in range(3))
    result = injection(h, e, t)
    assert result.shape == h.shape
    assert torch.isfinite(result).all()


def test_lti_spectral_radius():
    """Check that every entry returned by ``get_A()`` lies strictly in (0, 1)."""
    decay = LTIInjection(64).get_A()
    above_zero = (decay > 0).all()
    assert above_zero
    below_one = (decay < 1).all()
    assert below_one


def test_lti_learnable_params():
    """Check the shapes of ``log_A``, ``log_dt`` and ``B`` and the total parameter count.

    For ``LTIInjection(128)`` the three attributes must have shapes
    ``(128,)``, ``(1,)`` and ``(128,)``, and ``parameters()`` must hold
    exactly that many scalars in total.
    """
    dim = 128
    injection = LTIInjection(dim)
    expected_shapes = {"log_A": (dim,), "log_dt": (1,), "B": (dim,)}
    for attr, shape in expected_shapes.items():
        assert getattr(injection, attr).shape == shape
    total = sum(param.numel() for param in injection.parameters())
    assert total == dim + 1 + dim


def test_lti_state_decay():
    """Check that a large first argument is damped when the other two are zero.

    The first argument is filled with 100.0 and the remaining two with zeros;
    every output entry must then have magnitude below 50.
    """
    injection = LTIInjection(8)
    shape = (1, 1, 8)
    h = torch.full(shape, 100.0)
    # Two distinct zero tensors, as in a call with independent inputs.
    e = torch.zeros(shape)
    t = torch.zeros(shape)
    result = injection(h, e, t)
    assert (result.abs() < 50).all()


def test_lti_initial_state_small():
    """Check the output range when only the second argument is non-zero.

    With the first and third arguments all zeros and the second filled with
    5.0, every output entry must be strictly between 0 and 5.
    """
    injection = LTIInjection(8)
    shape = (1, 1, 8)
    h = torch.zeros(shape)
    e = torch.full(shape, 5.0)
    t = torch.zeros(shape)
    result = injection(h, e, t)
    assert (result > 0).all()
    assert (result < 5).all()


def test_bytehead_lti_integration():
    """Check that a default ``ByteHead`` carries an ``LTIInjection`` and emits 288 logits.

    A random (2, 10, 8192) input is run through the head; the last output
    dimension must be 288 and the ``lti`` attribute must be an
    ``LTIInjection`` instance.
    """
    head = ByteHead()
    logits = head(torch.randn(2, 10, 8192))
    assert logits.shape[-1] == 288
    injection = head.lti
    assert injection is not None
    assert isinstance(injection, LTIInjection)


def test_bytehead_no_act():
    """Check that ``ByteHead(act_max_iters=1)`` has no ``lti`` and still emits 288 logits.

    The ``lti`` attribute must be ``None``, and a random (1, 5, 8192) input
    must still produce an output whose last dimension is 288.
    """
    head = ByteHead(act_max_iters=1)
    assert head.lti is None
    logits = head(torch.randn(1, 5, 8192))
    assert logits.shape[-1] == 288