| """Tests for LTI injection in ACT loops.""" |
| import torch |
| from arbitor.components import LTIInjection, ByteHead |
| from arbitor.decoders import VideoHead |
|
|
|
|
def test_lti_basic_properties():
    """LTIInjection maps (h, e, t) to a finite tensor with the same shape as h."""
    batch, seq, dim = 2, 10, 64
    injector = LTIInjection(dim)
    # Three independent random inputs, drawn in the same order as h, e, t.
    hidden = torch.randn(batch, seq, dim)
    e_input = torch.randn(batch, seq, dim)
    t_input = torch.randn(batch, seq, dim)
    result = injector(hidden, e_input, t_input)
    assert result.shape == hidden.shape
    assert torch.all(torch.isfinite(result))
|
|
|
|
def test_lti_spectral_radius():
    """Every entry returned by LTIInjection.get_A() lies strictly inside (0, 1).

    NOTE(review): presumably A holds the diagonal of the state transition, so
    this bounds the spectral radius below 1 — confirm against LTIInjection.
    """
    transition = LTIInjection(64).get_A()
    assert torch.all(transition > 0)
    assert torch.all(transition < 1)
|
|
|
|
def test_lti_learnable_params():
    """LTIInjection(d) exposes log_A (d,), log_dt (1,), B (d,) and nothing else."""
    dim = 128
    lti = LTIInjection(dim)
    expected_shapes = {
        "log_A": (dim,),
        "log_dt": (1,),
        "B": (dim,),
    }
    for attr, shape in expected_shapes.items():
        assert getattr(lti, attr).shape == shape
    # The parameter count pins down that no other learnable tensors exist.
    total = sum(param.numel() for param in lti.parameters())
    assert total == dim + 1 + dim
|
|
|
|
def test_lti_state_decay():
    """With zero e and t inputs, a state of 100 comes out with magnitude below 50."""
    lti = LTIInjection(8)
    big_state = torch.full((1, 1, 8), 100.0)
    silent_e = torch.zeros_like(big_state)
    silent_t = torch.zeros_like(big_state)
    out = lti(big_state, silent_e, silent_t)
    assert torch.all(out.abs() < 50)
|
|
|
|
def test_lti_initial_state_small():
    """From a zero state, a constant e of 5 gives outputs strictly between 0 and 5."""
    shape = (1, 1, 8)
    lti = LTIInjection(8)
    zero_state = torch.zeros(shape)
    drive = torch.full(shape, 5.0)
    out = lti(zero_state, drive, torch.zeros(shape))
    assert torch.all(out > 0)
    assert torch.all(out < 5)
|
|
|
|
def test_bytehead_lti_integration():
    """A default ByteHead owns an LTIInjection and its logits end in a 288-wide dim."""
    head = ByteHead()
    # NOTE(review): 8192 is presumably ByteHead's default input width — confirm.
    features = torch.randn(2, 10, 8192)
    logits = head(features)
    assert logits.size(-1) == 288
    assert head.lti is not None
    assert isinstance(head.lti, LTIInjection)
|
|
|
|
def test_bytehead_no_act():
    """ByteHead(act_max_iters=1) builds no LTI module but still yields 288-wide logits."""
    head = ByteHead(act_max_iters=1)
    assert head.lti is None
    out_logits = head(torch.randn(1, 5, 8192))
    assert out_logits.size(-1) == 288
|
|