# arm-gym/tests/test_kernels.py
# kaori02's picture
# arm-gym: single-tree sync (HF Hub 10MiB limit; drop notebook run artifacts from history)
# a2ffabc
from arm_gym.kernels import TEMPLATES, generate_all, generate_variants, split_train_eval, summary
def test_template_count_matches_audit_target():
    """Check that ``TEMPLATES`` holds between 20 and 30 entries, inclusive."""
    template_count = len(TEMPLATES)
    # Lower and upper bounds are checked separately so a failure names the side.
    assert template_count >= 20
    assert template_count <= 30
def test_variants_reach_500_plus():
    """Check that ``generate_all()`` yields at least 500 variants."""
    minimum_expected = 500
    all_v = generate_all()
    # The failure message reports the size actually produced.
    assert minimum_expected <= len(all_v), f"got {len(all_v)}"
def test_variant_ids_are_unique():
    """Check that no ``variant_id`` appears more than once in ``generate_all()``.

    On failure the assertion message lists the ids that were seen more than
    once, so the colliding variants can be found without re-running by hand.
    """
    # Function-scope import keeps this test self-contained.
    from collections import Counter

    occurrences = Counter(v.variant_id for v in generate_all())
    duplicates = [vid for vid, seen in occurrences.items() if seen > 1]
    # An empty list is equivalent to len(ids) == len(set(ids)).
    assert not duplicates, f"duplicate variant ids: {duplicates}"
def test_vec_add_renders_valid_c():
    """Check the first ``vec_add`` variant's ``c_source`` for expected fragments."""
    first_variant = next(generate_variants("vec_add"))
    # Fragments are checked in order: the stddef include, "kernel", then "for".
    for fragment in ("#include <stddef.h>", "kernel", "for"):
        assert fragment in first_variant.c_source
def test_split_is_deterministic_and_disjoint():
    """Check that a seeded split is repeatable and train/eval ids never overlap.

    ``split_train_eval`` is called twice with the same ``eval_frac`` and
    ``seed``; both calls must give the same train id set and the same eval id
    set, and the first call's train and eval id sets must share no element.
    """

    def ids_of(group):
        # Collapse a group of variants to the set of their ids.
        return {v.variant_id for v in group}

    variants = generate_all()
    train_a, eval_a = split_train_eval(variants, eval_frac=0.1, seed=42)
    train_b, eval_b = split_train_eval(variants, eval_frac=0.1, seed=42)

    # Same seed, same partition.
    assert ids_of(train_a) == ids_of(train_b)
    assert ids_of(eval_a) == ids_of(eval_b)

    # No variant id may land on both sides of the split.
    assert ids_of(train_a).isdisjoint(ids_of(eval_a))
def test_summary_shape():
    """Check ``summary()``'s "templates" and "variants" entries.

    "templates" must equal ``len(TEMPLATES)`` and "variants" must be at
    least 500.
    """
    report = summary()
    reported_templates = report["templates"]
    assert reported_templates == len(TEMPLATES)
    reported_variants = report["variants"]
    assert reported_variants >= 500