# chimera/tests/test_chimera.py
import pytest
torch = pytest.importorskip("torch")
from chimera import (
    Chimera51ForCausalLM, ChimeraTokenizer, load_config, scale_config,
    pack_ternary, unpack_ternary,
)
from chimera.inference import SpanBank
from chimera.moe import MoELayer
from chimera.quantization import BitLinear, ternarize_weight
def cfg():
    """Nano-scale test config: tiny vocab, span inference disabled for speed."""
    c = scale_config(load_config("config.json"), "nano")
    c["vocab_size"] = 512
    c["span_inference"]["enabled"] = False
    return c
def test_pack_unpack_roundtrip():
    """pack_ternary/unpack_ternary must be lossless for {-1, 0, 1} values."""
    q = torch.tensor([[-1, 0, 1, 1, -1, 0, 1, 0, -1]], dtype=torch.int8)
    packed = pack_ternary(q)
    out = unpack_ternary(packed, q.shape[-1], dtype=torch.float32).to(torch.int8)
    assert torch.equal(q, out)
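# Hedged extension of the roundtrip test above: exercise pack_ternary /
# unpack_ternary on random ternary tensors, including lengths that are not a
# multiple of the presumed values-per-byte packing factor. Only the signatures
# already used in this file are relied on; nothing else is assumed.
def test_pack_unpack_roundtrip_random_lengths():
    for n in (1, 4, 7, 16, 33):
        q = torch.randint(-1, 2, (2, n), dtype=torch.int8)
        packed = pack_ternary(q)
        out = unpack_ternary(packed, n, dtype=torch.float32).to(torch.int8)
        assert torch.equal(q, out)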
def test_ternarize_weight_basic():
    """ternarize_weight returns a {-1, 0, 1} tensor plus a per-row scale."""
    w = torch.randn(8, 16) * 0.5
    wq, alpha = ternarize_weight(w)
    assert wq.shape == w.shape
    assert alpha.shape == (8,)
    assert (wq.unique().abs() <= 1).all()
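# Sketch of a scale-equivariance check, assuming a BitNet-style absmean scale:
# if alpha is linear in |w| and the ternarization threshold scales with it,
# then doubling w should double alpha and leave the codes unchanged. This is
# an assumption about ternarize_weight's internals, not documented behavior.
def test_ternarize_weight_scale_equivariance():
    w = torch.randn(4, 8)
    wq1, a1 = ternarize_weight(w)
    wq2, a2 = ternarize_weight(2.0 * w)
    assert torch.equal(wq1, wq2)         # codes invariant to uniform scaling
    assert torch.allclose(2.0 * a1, a2)  # scale grows linearly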
def test_bitlinear_forward_backward_and_packed():
    """BitLinear trains with finite gradients and still works after packing."""
    layer = BitLinear(7, 5)
    x = torch.randn(3, 7, requires_grad=True)
    y = layer(x).sum()
    y.backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()
    assert layer.weight.grad is not None
    layer.prepare_for_inference()
    layer.eval()
    with torch.no_grad():
        out = layer(torch.randn(2, 7))
    assert out.shape == (2, 5)
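# Hedged consistency check: assuming prepare_for_inference() only repacks the
# already-ternarized weights (i.e. the quantization-aware eval forward and the
# packed forward compute the same function), outputs should agree closely.
def test_bitlinear_packed_matches_eval():
    layer = BitLinear(8, 4)
    layer.eval()
    x = torch.randn(2, 8)
    with torch.no_grad():
        before = layer(x)
        layer.prepare_for_inference()
        after = layer(x)
    assert torch.allclose(before, after, atol=1e-4)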
def test_bitlinear_dense_cache_consistency():
    """BitLinear's dense-weight cache must give identical results across calls."""
    layer = BitLinear(8, 4)
    layer.eval()
    layer.prepare_for_inference()
    x = torch.randn(2, 8)
    with torch.no_grad():
        out1 = layer(x)
        out2 = layer(x)
    assert torch.allclose(out1, out2)
def test_model_forward_loss_and_backward():
    """Forward pass yields full logits and a finite loss that backpropagates."""
    model = Chimera51ForCausalLM(cfg())
    x = torch.randint(0, 512, (2, 8))
    y = torch.randint(0, 512, (2, 8))
    out = model(x, labels=y)
    assert out.logits.shape == (2, 8, 512)
    assert torch.isfinite(out.loss)
    out.loss.backward()
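# Hedged sketch: Chimera51ForCausalLM is named like a Hugging Face causal LM,
# so it may expose an HF-style generate() method. That API is an assumption,
# so the test skips rather than fails when it is absent.
def test_model_generate_shape_if_available():
    model = Chimera51ForCausalLM(cfg()).eval()
    if not hasattr(model, "generate"):
        pytest.skip("model does not expose generate()")
    prompt = torch.randint(0, 512, (1, 4))
    with torch.inference_mode():
        out = model.generate(prompt, max_new_tokens=3)
    assert out.shape[0] == 1 and out.shape[1] >= prompt.shape[1]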
def test_model_kv_cache_consistency():
    """Generation with KV-cache must match generation without it."""
    config = cfg()
    config["looping"]["enabled"] = False  # determinism for the equivalence check
    model = Chimera51ForCausalLM(config).eval()
    model.prepare_for_inference()
    prompt = torch.randint(0, 512, (1, 4))
    with torch.inference_mode():
        # No-cache: feed the full sequence each time.
        cur = prompt.clone()
        no_cache_tokens = []
        for _ in range(3):
            out = model(cur, logits_to_keep=1)
            tok = out.logits[:, -1].argmax(-1, keepdim=True)
            cur = torch.cat([cur, tok], dim=1)
            no_cache_tokens.append(int(tok.item()))
        # KV-cache: feed only the new token after the first call.
        out = model(prompt, use_cache=True, logits_to_keep=1)
        caches = out.caches
        tok = out.logits[:, -1].argmax(-1, keepdim=True)
        cache_tokens = [int(tok.item())]
        for _ in range(2):
            out = model(tok, caches=caches, use_cache=True, logits_to_keep=1)
            caches = out.caches
            tok = out.logits[:, -1].argmax(-1, keepdim=True)
            cache_tokens.append(int(tok.item()))
    assert no_cache_tokens == cache_tokens
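# Hedged check of the logits_to_keep argument used above: the KV-cache test
# only reads out.logits[:, -1], which suggests logits_to_keep=k trims the
# returned logits to the last k positions. That trimming behavior is assumed.
def test_logits_to_keep_trims_output():
    model = Chimera51ForCausalLM(cfg()).eval()
    x = torch.randint(0, 512, (1, 6))
    with torch.inference_mode():
        out = model(x, logits_to_keep=1)
    assert out.logits.shape == (1, 1, 512)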
def test_moe_and_span_bank_shapes():
    """MoELayer preserves input shape; SpanBank returns one vector per query."""
    moe = MoELayer(32, 64, n_routed_experts=3, n_shared_experts=1, num_experts_per_tok=2)
    x = torch.randn(2, 4, 32)
    assert moe(x).shape == x.shape
    bank = SpanBank(max_entries=8, hidden_size=32)
    bank.add(torch.randn(3, 32), torch.randn(3, 32))
    assert bank.query(torch.randn(5, 32)).shape == (5, 32)
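# Gradient-flow sanity check for the MoE layer: a sketch assuming MoELayer is
# an ordinary differentiable nn.Module, so top-k routing should still let
# finite gradients reach the input.
def test_moe_backward_gradients():
    moe = MoELayer(32, 64, n_routed_experts=3, n_shared_experts=1, num_experts_per_tok=2)
    x = torch.randn(2, 4, 32, requires_grad=True)
    moe(x).sum().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()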
def test_tokenizer_fallback_roundtrip():
    """The fallback tokenizer must round-trip plain ASCII text exactly."""
    tok = ChimeraTokenizer(vocab_size=512)
    text = "hello cpu"
    assert tok.decode(tok.encode(text)) == text
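# Hedged extension of the roundtrip above: assuming encode() returns a
# sequence of integer ids, every id should fall inside the declared vocab.
# Non-ASCII round-tripping is not assumed, so only the id range is checked.
def test_tokenizer_ids_within_vocab():
    tok = ChimeraTokenizer(vocab_size=512)
    ids = tok.encode("hello cpu 123!")
    assert all(0 <= int(i) < 512 for i in ids)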