File size: 1,163 Bytes
a2ffabc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from arm_gym.kernels import TEMPLATES, generate_all, generate_variants, split_train_eval, summary


def test_template_count_matches_audit_target():
    """The template catalogue size must stay inside the audited 20..30 window."""
    template_total = len(TEMPLATES)
    assert template_total >= 20 and template_total <= 30


def test_variants_reach_500_plus():
    """Expanding every template via ``generate_all()`` yields at least 500 variants."""
    variant_total = len(generate_all())
    assert variant_total >= 500, f"got {variant_total}"


def test_variant_ids_are_unique():
    """Every variant from ``generate_all()`` must carry a distinct ``variant_id``.

    On failure the assertion message lists the IDs that occur more than once
    (capped at 20), so the offending template/parameter combination can be
    located directly instead of only seeing two unequal lengths.
    """
    from collections import Counter

    id_counts = Counter(v.variant_id for v in generate_all())
    duplicates = [vid for vid, count in id_counts.items() if count > 1]
    # Equivalent to len(ids) == len(set(ids)): no ID may appear twice.
    assert not duplicates, (
        f"{len(duplicates)} duplicate variant_id(s), e.g. {duplicates[:20]}"
    )


def test_vec_add_renders_valid_c():
    """The first ``vec_add`` variant renders C source containing the expected pieces.

    Checks for the ``<stddef.h>`` include, a ``kernel`` identifier, and an
    actual ``for (...)`` loop header. A bare ``"for"`` substring check would
    also be satisfied by words such as ``format`` or by text in a comment, so
    the loop is matched on a word boundary followed by an opening parenthesis.
    """
    import re

    # Use a default so an empty generator fails with a clear message rather
    # than surfacing as a bare StopIteration.
    v = next(generate_variants("vec_add"), None)
    assert v is not None, "generate_variants('vec_add') produced no variants"

    src = v.c_source
    assert "#include <stddef.h>" in src
    # Substring (not word) match: the kernel may be named e.g. `vec_add_kernel`.
    assert "kernel" in src
    assert re.search(r"\bfor\s*\(", src), "no for-loop header found in vec_add C source"


def test_split_is_deterministic_and_disjoint():
    """``split_train_eval`` is reproducible for a fixed seed and partitions its input.

    Two calls with the same ``eval_frac`` and ``seed`` must put the same
    variant IDs on each side. No ID may appear on both sides, and together the
    sides must account for every input variant; otherwise a split that drops
    variants (or returns two empty lists) would pass.
    """
    vs = generate_all()

    def _ids(variants):
        # Set of variant IDs for membership comparisons (order-insensitive).
        return {v.variant_id for v in variants}

    t1, e1 = split_train_eval(vs, eval_frac=0.1, seed=42)
    t2, e2 = split_train_eval(vs, eval_frac=0.1, seed=42)
    train_ids, eval_ids = _ids(t1), _ids(e1)

    # Determinism: same seed -> same membership on both sides.
    assert train_ids == _ids(t2)
    assert eval_ids == _ids(e2)

    # Disjointness: no variant may land in both train and eval.
    overlap = train_ids & eval_ids
    assert overlap == set()

    # Coverage: train + eval must be exactly the input, each variant once.
    assert train_ids | eval_ids == _ids(vs)
    assert len(t1) + len(e1) == len(vs)

    # With eval_frac=0.1 over the full pool the eval side must not be empty.
    assert e1, "eval split is empty"


def test_summary_shape():
    """``summary()`` reports the template count and a variant total of 500+."""
    report = summary()
    expected_templates = len(TEMPLATES)
    assert report["templates"] == expected_templates
    assert report["variants"] >= 500