File size: 3,686 Bytes
6e408ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pytest

torch = pytest.importorskip("torch")

from chimera import (
    Chimera51ForCausalLM, ChimeraTokenizer, load_config, scale_config,
    pack_ternary, unpack_ternary,
)
from chimera.inference import SpanBank
from chimera.moe import MoELayer
from chimera.quantization import BitLinear, ternarize_weight


def cfg():
    """Build a small model config for fast unit tests.

    Loads ``config.json`` (resolved by ``load_config``; presumably relative to
    the working directory — confirm), scales it with the ``"nano"`` preset,
    shrinks the vocabulary to 512 entries and turns span inference off.
    """
    config = scale_config(load_config("config.json"), "nano")
    config["vocab_size"] = 512
    span_settings = config["span_inference"]
    span_settings["enabled"] = False
    return config


def test_pack_unpack_roundtrip():
    """Packing a ternary row and unpacking it must reproduce it exactly.

    The row length (9) is deliberately odd so the unpacker has to honour the
    explicit length argument rather than any padding in the packed form.
    """
    values = [-1, 0, 1, 1, -1, 0, 1, 0, -1]
    original = torch.tensor([values], dtype=torch.int8)
    n_values = original.shape[-1]
    restored = unpack_ternary(pack_ternary(original), n_values, dtype=torch.float32)
    assert torch.equal(original, restored.to(torch.int8))


def test_ternarize_weight_basic():
    """ternarize_weight keeps the weight's shape and returns one scale per row."""
    weight = 0.5 * torch.randn(8, 16)
    quantized, scale = ternarize_weight(weight)
    assert quantized.shape == weight.shape
    # One scale entry for each of the 8 rows of the weight matrix.
    assert scale.shape == (8,)
    # Every distinct quantized value must lie within [-1, 1].
    distinct = quantized.unique()
    assert (distinct.abs() <= 1).all()


def test_bitlinear_forward_backward_and_packed():
    """BitLinear back-propagates finite gradients, then still runs after
    ``prepare_for_inference`` (presumably switching to a packed weight
    representation, per the test name — confirm in BitLinear)."""
    in_features, out_features = 7, 5
    layer = BitLinear(in_features, out_features)

    # Training path: gradients must reach both the input and the weight.
    inputs = torch.randn(3, in_features, requires_grad=True)
    layer(inputs).sum().backward()
    grad = inputs.grad
    assert grad is not None and torch.isfinite(grad).all()
    assert layer.weight.grad is not None

    # Inference path: prepare first, then eval, then run without autograd.
    layer.prepare_for_inference()
    layer.eval()
    with torch.no_grad():
        result = layer(torch.randn(2, in_features))
    assert result.shape == (2, out_features)


def test_bitlinear_dense_cache_consistency():
    """Two inference calls on the same input must agree.

    After ``prepare_for_inference`` the layer may cache a derived weight
    (hedged — the test name suggests a dense cache); reusing it must not
    change the result.
    """
    layer = BitLinear(8, 4)
    layer.eval()
    layer.prepare_for_inference()
    sample = torch.randn(2, 8)
    with torch.no_grad():
        first, second = layer(sample), layer(sample)
    assert torch.allclose(first, second)


def test_model_forward_loss_and_generate_shape():
    """A forward pass with labels yields full-vocabulary logits and a finite,
    differentiable loss whose gradients reach the model's parameters."""
    # Seed so weights and inputs are reproducible across runs.
    torch.manual_seed(0)
    config = cfg()
    # Derive the vocabulary size from the config instead of repeating 512.
    vocab = config["vocab_size"]
    batch, seq_len = 2, 8

    model = Chimera51ForCausalLM(config)
    input_ids = torch.randint(0, vocab, (batch, seq_len))
    labels = torch.randint(0, vocab, (batch, seq_len))
    out = model(input_ids, labels=labels)

    assert out.logits.shape == (batch, seq_len, vocab)
    assert torch.isfinite(out.loss)
    out.loss.backward()
    # backward() must actually propagate into at least one parameter; some
    # (e.g. unused experts) may legitimately stay without a gradient.
    assert any(p.grad is not None for p in model.parameters())


def test_model_kv_cache_consistency():
    """Generation with KV-cache must match generation without it.

    Both paths decode greedily for the same number of steps from the same
    prompt. The RNG is seeded so that weights and prompt are reproducible:
    any mismatch shows up on every run instead of intermittently.
    """
    torch.manual_seed(0)
    config = cfg()
    config["looping"]["enabled"] = False  # determinism for the equivalence check
    model = Chimera51ForCausalLM(config).eval()
    model.prepare_for_inference()

    n_new = 3  # number of tokens generated by each decoding path
    prompt = torch.randint(0, config["vocab_size"], (1, 4))
    with torch.inference_mode():
        # No-cache: feed the full, growing sequence each time.
        seq = prompt.clone()
        no_cache_tokens = []
        for _ in range(n_new):
            out = model(seq, logits_to_keep=1)
            nxt = out.logits[:, -1].argmax(-1, keepdim=True)
            seq = torch.cat([seq, nxt], dim=1)
            no_cache_tokens.append(int(nxt.item()))

        # KV-cache: prefill with the prompt, then feed only the new token.
        out = model(prompt, use_cache=True, logits_to_keep=1)
        caches = out.caches
        nxt = out.logits[:, -1].argmax(-1, keepdim=True)
        cache_tokens = [int(nxt.item())]
        for _ in range(n_new - 1):
            out = model(nxt, caches=caches, use_cache=True, logits_to_keep=1)
            caches = out.caches
            nxt = out.logits[:, -1].argmax(-1, keepdim=True)
            cache_tokens.append(int(nxt.item()))

    assert no_cache_tokens == cache_tokens


def test_moe_and_span_bank_shapes():
    """MoELayer preserves its input shape; SpanBank.query returns one
    hidden-size vector per query row."""
    hidden = 32

    moe = MoELayer(hidden, 64, n_routed_experts=3, n_shared_experts=1, num_experts_per_tok=2)
    tokens = torch.randn(2, 4, hidden)
    assert moe(tokens).shape == tokens.shape

    bank = SpanBank(max_entries=8, hidden_size=hidden)
    bank.add(torch.randn(3, hidden), torch.randn(3, hidden))
    queries = torch.randn(5, hidden)
    assert bank.query(queries).shape == (5, hidden)


def test_tokenizer_fallback_roundtrip():
    """encode() followed by decode() must give back the original text.

    Presumably exercises the tokenizer's fallback path (per the test name) —
    confirm in ChimeraTokenizer.
    """
    tokenizer = ChimeraTokenizer(vocab_size=512)
    sample = "hello cpu"
    ids = tokenizer.encode(sample)
    assert tokenizer.decode(ids) == sample