| """Comprehensive edge-case tests for arm-gym. |
| |
| Covers edge cases not in the existing test suite: NaN/unicode/truncation in |
| errors, empty/boundary params in kernels, zero-cycles and register aliasing in |
| mca, std=0 and mismatched lengths in reward, exception handling in rollout, |
| nested JSON in multi_agent, curriculum regression in env, sigma boundary in |
| verifier, and smoke/detect in train. |
| """ |

from __future__ import annotations

import itertools
import json
import math
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from arm_gym.errors import ErrorKind, StructuredError, VerifierResult
from arm_gym.kernels import (
    TEMPLATES,
    KernelTemplate,
    KernelVariant,
    generate_all,
    generate_variants,
    split_train_eval,
)
from arm_gym.mca import McaReport, parse_mca, uses_neon_with_liveness
from arm_gym.multi_agent import (
    OpportunityToken,
    analyzer_credit,
    build_optimizer_prompt,
    parse_analyzer_output,
    per_role_reward,
)
from arm_gym.reward import RewardConfig, batch_rewards, group_zscore_clip, raw_reward
from arm_gym.rollout_budget import (
    TestCase,
    run_parallel,
    select_adversarial,
    update_adversarial_ranks,
)
from arm_gym.verifier import load_baseline_distribution, sigma_sanity_ok


class TestErrorsEdgeCases:
    """StructuredError and VerifierResult edge cases."""

    @pytest.mark.parametrize("kind", list(ErrorKind))
    def test_all_error_kinds_roundtrip_json(self, kind: ErrorKind):
        """Every ErrorKind value must survive a to_prompt -> json.loads cycle."""
        err = StructuredError(kind=kind, message=f"test for {kind.value}")
        payload = json.loads(err.to_prompt())
        assert payload["kind"] == kind.value
        assert payload["message"] == f"test for {kind.value}"
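    # The parametrize above walks the full ErrorKind enum, so the members
    # exercised individually later in this file (TIMEOUT, ASSEMBLE_FAIL,
    # SEGFAULT, OUTPUT_MISMATCH) are all covered by this roundtrip as well.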

    def test_verifier_result_nan_speedup(self):
        """NaN speedup should not crash and should serialize."""
        v = VerifierResult(ok=True, reward=0.0, speedup=float("nan"))
        assert math.isnan(v.speedup)

    def test_structured_error_empty_message(self):
        """An empty message should be omitted from the prompt JSON."""
        err = StructuredError(kind=ErrorKind.TIMEOUT, message="")
        payload = json.loads(err.to_prompt())
        assert "message" not in payload
        assert payload["kind"] == "timeout"

    def test_structured_error_unicode_message(self):
        """Unicode chars in the message should survive the roundtrip."""
        msg = "invalid token → éñ \U0001f4a5"
        err = StructuredError(kind=ErrorKind.ASSEMBLE_FAIL, message=msg)
        payload = json.loads(err.to_prompt())
        assert payload["message"] == msg

    def test_structured_error_very_long_message(self):
        """Very long messages should not crash to_prompt (no built-in truncation)."""
        msg = "A" * 100_000
        err = StructuredError(kind=ErrorKind.SEGFAULT, message=msg)
        payload = json.loads(err.to_prompt())
        assert len(payload["message"]) == 100_000

    def test_structured_error_expected_actual_complex(self):
        """Expected/actual can be any JSON-serializable type."""
        err = StructuredError(
            kind=ErrorKind.OUTPUT_MISMATCH,
            message="mismatch",
            expected=[1, 2, 3],
            actual={"x": 42},
        )
        payload = json.loads(err.to_prompt())
        assert payload["expected"] == [1, 2, 3]
        assert payload["actual"] == {"x": 42}

    def test_verifier_result_all_none_optionals(self):
        """VerifierResult with all optional fields left as None."""
        v = VerifierResult(ok=False, reward=0.0)
        assert v.error is None
        assert v.agent_cycles is None
        assert v.baseline_cycles is None
        assert v.speedup is None
        assert v.mca_dispatch_stalls is None
        assert v.mca_resource_pressure_p99 is None


class TestKernelsEdgeCases:
    """Kernel template rendering, variant identity, and train/eval splits."""

    def test_variant_id_stability(self):
        """Same template + params must produce identical variant_ids."""
        v1 = KernelVariant(
            template_name="vec_add",
            params=(("dtype", "float"), ("n", 64)),
            c_source="...",
        )
        v2 = KernelVariant(
            template_name="vec_add",
            params=(("dtype", "float"), ("n", 64)),
            c_source="...",
        )
        assert v1.variant_id == v2.variant_id

    def test_variant_id_changes_with_params(self):
        """Different params must produce different variant_ids."""
        v1 = KernelVariant("vec_add", (("dtype", "float"), ("n", 64)), "...")
        v2 = KernelVariant("vec_add", (("dtype", "float"), ("n", 128)), "...")
        assert v1.variant_id != v2.variant_id

    def test_split_eval_frac_zero(self):
        """eval_frac=0 should put everything in train, nothing in eval."""
        variants = generate_all()
        train, evalset = split_train_eval(variants, eval_frac=0.0, seed=0)
        assert len(evalset) == 0
        assert len(train) == len(variants)

    def test_split_eval_frac_one(self):
        """eval_frac=1.0 should put everything in eval, nothing in train."""
        variants = generate_all()
        train, evalset = split_train_eval(variants, eval_frac=1.0, seed=0)
        assert len(train) == 0
        assert len(evalset) == len(variants)

    def test_split_different_seeds_differ(self):
        """Different seeds should produce different splits."""
        variants = generate_all()
        _, e1 = split_train_eval(variants, eval_frac=0.1, seed=0)
        _, e2 = split_train_eval(variants, eval_frac=0.1, seed=999)
        ids1 = {v.variant_id for v in e1}
        ids2 = {v.variant_id for v in e2}
        # Identical samples under different seeds are possible in principle,
        # but vanishingly unlikely for any non-trivial variant pool.
        assert ids1 != ids2

    def test_all_templates_render_valid_c_with_kernel(self):
        """Every template's first variant should contain a 'kernel' function."""
        for name in TEMPLATES:
            v = next(generate_variants(name))
            assert "kernel" in v.c_source, f"{name} missing 'kernel' in rendered C"

    def test_softmax_numer_has_no_dtype_param(self):
        """softmax_numer only takes n (float-only), so 'dtype' is not in params."""
        t = TEMPLATES["softmax_numer"]
        assert "dtype" not in t.params
        assert "n" in t.params

    def test_softmax_numer_renders_float_only(self):
        """softmax_numer should always render with the float type."""
        v = next(generate_variants("softmax_numer"))
        assert "float" in v.c_source

    def test_empty_template_params_edge(self):
        """A template with an empty param dict should yield exactly one variant."""
        t = KernelTemplate(
            name="noop",
            difficulty=1,
            params={},
            render=lambda: "void kernel(void) {}",
        )
        keys = list(t.params.keys())
        combos = list(itertools.product(*(t.params[k] for k in keys)))
        assert len(combos) == 1
        assert combos[0] == ()
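    # itertools.product() of zero iterables yields exactly one empty tuple,
    # i.e. list(itertools.product()) == [()], which is why a parameterless
    # template expands to a single variant.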


class TestMcaEdgeCases:
    """parse_mca parsing fallbacks and NEON liveness detection."""

    def test_parse_mca_missing_ipc_line(self):
        """When the IPC line is missing, IPC should be computed as instructions/cycles."""
        out = """\
Instructions: 10
Total Cycles: 5
"""
        rep = parse_mca(out)
        assert rep.total_cycles == 5
        assert rep.instructions == 10
        assert abs(rep.ipc - 2.0) < 0.01

    def test_parse_mca_zero_cycles_no_div_by_zero(self):
        """Zero cycles must not cause ZeroDivisionError in the IPC fallback."""
        out = """\
Instructions: 10
Total Cycles: 0
"""
        rep = parse_mca(out)
        assert rep.total_cycles == 0
        assert abs(rep.ipc - 10.0) < 0.01
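    # The expected 10.0 implies the fallback clamps the denominator,
    # presumably something like instructions / max(total_cycles, 1)
    # -> 10 / max(0, 1) = 10.0 (inferred from the assert above, not
    # from parse_mca's source).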

    def test_parse_mca_missing_both_raises(self):
        """Missing both Instructions and Total Cycles should raise RuntimeError."""
        with pytest.raises(RuntimeError, match="could not parse"):
            parse_mca("nothing useful here")

    def test_parse_mca_empty_resource_pressure_table(self):
        """An empty resource pressure table should yield p99 = 0.0."""
        out = """\
Instructions: 10
Total Cycles: 5
IPC: 2.0
Resource pressure per iteration:

"""
        rep = parse_mca(out)
        assert rep.resource_pressure_p99 == 0.0

    def test_uses_neon_with_liveness_empty_string(self):
        """An empty assembly string should return False."""
        assert uses_neon_with_liveness("") is False

    def test_uses_neon_with_liveness_only_comments(self):
        """Assembly containing only comments should return False."""
        asm = """\
// comment1
; comment2
# comment3
"""
        assert uses_neon_with_liveness(asm) is False

    def test_uses_neon_with_liveness_only_directives(self):
        """Assembly containing only directives should return False."""
        asm = """\
.text
.globl kernel
.type kernel, %function
.size kernel, .-kernel
"""
        assert uses_neon_with_liveness(asm) is False

    def test_uses_neon_register_aliasing_q0_v0(self):
        """q0 and v0 alias the same NEON register (register 0).

        A write via v0 consumed by a store via q0 should be detected as live.
        """
        asm = """\
fmov v0.4s, #1.0
str q0, [x0]
ret
"""
        assert uses_neon_with_liveness(asm) is True
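    # AArch64 register-file note: b0/h0/s0/d0/q0 are the 8/16/32/64/128-bit
    # views of the same 128-bit vector register v0, so the liveness pass has
    # to normalize every alias down to one register number for this to hold.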

    def test_uses_neon_register_aliasing_d0_v0(self):
        """d0 and v0 share the same underlying register number."""
        asm = """\
fmov d0, #1.0
str d0, [x0]
ret
"""
        assert uses_neon_with_liveness(asm) is True

    def test_uses_neon_mixed_neon_and_scalar(self):
        """Mixed NEON and scalar instructions; the NEON value is live-out via the store."""
        asm = """\
mov w0, #42
fmov v1.4s, #2.0
add w0, w0, #1
st1 {v1.4s}, [x1]
ret
"""
        assert uses_neon_with_liveness(asm) is True

    def test_uses_neon_mixed_neon_dead_scalar_live(self):
        """A NEON reg is written but never consumed; only scalar ops are live."""
        asm = """\
fmov v3.4s, #1.0
mov w0, #42
ret
"""
        assert uses_neon_with_liveness(asm) is False

    def test_mca_report_no_pipeline_hazard_property(self):
        """Boundary: stalls=0, p99=0.99 -> no hazard; p99=1.0 -> hazard."""
        r1 = McaReport(10, 10, 1.0, 0, 0.99, "")
        assert r1.no_pipeline_hazard is True

        r2 = McaReport(10, 10, 1.0, 0, 1.0, "")
        assert r2.no_pipeline_hazard is False

        r3 = McaReport(10, 10, 1.0, 1, 0.5, "")
        assert r3.no_pipeline_hazard is False
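    # Taken together, the three cases pin the property down to something like
    # dispatch_stalls == 0 and resource_pressure_p99 < 1.0 (an inference from
    # the expected values, not a quote of McaReport's implementation).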


class TestRewardEdgeCases:
    """raw_reward guard clauses and group_zscore_clip boundaries."""

    def _ok_result(self, speedup=1.5, stalls=0, p99=0.5):
        return VerifierResult(
            ok=True, reward=0.0, speedup=speedup,
            agent_cycles=100, baseline_cycles=100 * (speedup if speedup else 1),
            mca_dispatch_stalls=stalls, mca_resource_pressure_p99=p99,
        )

    def test_raw_reward_speedup_none_on_ok_true(self):
        """ok=True but speedup=None should yield 0.0 (guard clause)."""
        v = VerifierResult(ok=True, reward=0.0, speedup=None)
        assert raw_reward(v, RewardConfig()) == 0.0

    def test_raw_reward_speedup_nan(self):
        """NaN speedup on ok=True; max/min may keep or drop NaN depending on
        argument order, so we only assert the result is a float."""
        v = VerifierResult(ok=True, reward=0.0, speedup=float("nan"))
        r = raw_reward(v, RewardConfig())
        assert isinstance(r, float)

    def test_group_zscore_clip_all_identical(self):
        """All identical values -> std=0 -> the 1e-6 fallback; result near 0."""
        out = group_zscore_clip([0.5, 0.5, 0.5, 0.5])
        assert all(abs(r) < 1e-3 for r in out)
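    # Worked through: mean = 0.5, std = 0, so each z-score is
    # (0.5 - 0.5) / max(std, 1e-6) = 0.0 / 1e-6 = 0.0 exactly.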

    def test_group_zscore_clip_single_element(self):
        """A single element (len < 2) is returned as-is."""
        out = group_zscore_clip([42.0])
        assert out == [42.0]

    def test_group_zscore_clip_empty_list(self):
        """An empty list (len < 2) is returned as-is."""
        out = group_zscore_clip([])
        assert out == []

    def test_batch_rewards_mismatched_lengths_guard(self):
        """batch_rewards with asms shorter than results.

        zip() silently truncates, so the result should have the length of the
        shorter list. This tests the actual behavior (no explicit guard).
        """
        results = [self._ok_result(1.5), self._ok_result(1.2), self._ok_result(1.0)]
        asms = ["asm1"]
        out = batch_rewards(results, RewardConfig(), asms=asms)
        assert len(out) == 1
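    # Standard zip() semantics: iteration stops at the shortest argument,
    # e.g. list(zip([1, 2, 3], ["a"])) == [(1, "a")].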

    def test_shaped_mode_without_asm_none(self):
        """Shaped mode with asm=None skips the NEON bonus entirely."""
        cfg = RewardConfig(mode="shaped")
        v = self._ok_result(1.5, stalls=0, p99=0.5)
        r = raw_reward(v, cfg, asm=None)
        assert abs(r - 0.6) < 0.01

    def test_batch_rewards_with_none_asms(self):
        """batch_rewards with asms=None should work (default path)."""
        results = [self._ok_result(1.5), self._ok_result(1.2)]
        out = batch_rewards(results, RewardConfig(), asms=None)
        assert len(out) == 2
        assert all(-1.5 <= r <= 1.5 for r in out)


class TestRolloutBudgetEdgeCases:
    """Adversarial test selection and parallel-execution edge cases."""

    def test_select_adversarial_empty_list(self):
        """An empty test list should return empty."""
        assert select_adversarial([]) == []

    def test_select_adversarial_n_greater_than_list(self):
        """n > len(tests) should return all tests (sorted by rank)."""
        tests = [TestCase((), None, adversarial_rank=r) for r in [0.1, 0.5]]
        picked = select_adversarial(tests, n=100)
        assert len(picked) == 2
        assert picked[0].adversarial_rank >= picked[1].adversarial_rank

    def test_run_parallel_with_raising_callable(self):
        """A callable that raises should count as a failure."""
        def explode(t):
            raise ValueError("boom")

        tests = [TestCase((), None) for _ in range(3)]
        ok, idx = run_parallel(explode, tests)
        assert ok is False
        assert 0 <= idx < 3

    def test_run_parallel_empty_test_list(self):
        """An empty test list should return (True, -1)."""
        ok, idx = run_parallel(lambda t: True, [])
        assert ok is True
        assert idx == -1

    def test_update_adversarial_ranks_out_of_bounds_index(self):
        """Out-of-bounds failure indices should be silently ignored."""
        tests = [TestCase((), None, adversarial_rank=1.0)]
        update_adversarial_ranks(tests, failures_by_index=[5, -2, 100])
        assert tests[0].adversarial_rank == pytest.approx(0.9)

    def test_update_adversarial_ranks_with_valid_index(self):
        """A valid failure index gets decay + boost."""
        tests = [TestCase((), None, adversarial_rank=1.0)]
        update_adversarial_ranks(tests, failures_by_index=[0])
        assert tests[0].adversarial_rank == pytest.approx(1.9)
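    # The two expectations above imply rank' = rank * 0.9 for survivors and
    # rank' = rank * 0.9 + 1.0 for failures (1.0 -> 0.9, 1.0 -> 1.9); the
    # decay/boost constants are inferred from the asserts, not from the source.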

    def test_concurrent_run_parallel_correctness(self):
        """Multiple run_parallel invocations should not interfere."""
        def always_pass(t):
            return True

        tests_a = [TestCase((), None) for _ in range(10)]
        tests_b = [TestCase((), None) for _ in range(10)]

        ok_a, _ = run_parallel(always_pass, tests_a)
        ok_b, _ = run_parallel(always_pass, tests_b)
        assert ok_a is True
        assert ok_b is True

    def test_select_adversarial_stable_sort_on_equal_ranks(self):
        """When all ranks are equal, ordering is stable (sorted() is stable)."""
        tests = [TestCase((i,), None, adversarial_rank=1.0) for i in range(5)]
        picked = select_adversarial(tests, n=5)
        assert len(picked) == 5


class TestMultiAgentEdgeCases:
    """Analyzer-output parsing, prompt building, and per-role credit."""

    def test_parse_analyzer_output_nested_json(self):
        """Nested JSON objects inside the opportunity tokens."""
        text = '[{"kind":"vectorize_neon","rationale":"nested: {\\"x\\": 1}"}]'
        toks = parse_analyzer_output(text)
        assert len(toks) == 1
        assert toks[0].kind == "vectorize_neon"

    def test_parse_analyzer_output_empty_array(self):
        """An empty JSON array should return an empty list."""
        toks = parse_analyzer_output("[]")
        assert toks == []

    def test_parse_analyzer_output_array_non_dict_items(self):
        """Non-dict items in the array should be skipped."""
        text = '[42, "string", {"kind":"use_ldp_stp"}, null, true]'
        toks = parse_analyzer_output(text)
        assert len(toks) == 1
        assert toks[0].kind == "use_ldp_stp"

    def test_parse_analyzer_output_dict_missing_kind(self):
        """A dict without a 'kind' key should be skipped."""
        text = '[{"target_region":"loop"}, {"kind":"unroll_loop"}]'
        toks = parse_analyzer_output(text)
        assert len(toks) == 1
        assert toks[0].kind == "unroll_loop"

    def test_build_optimizer_prompt_empty_tokens(self):
        """An empty tokens list should still produce a valid prompt."""
        prompt = build_optimizer_prompt("int k(){}", ".text\nret", [])
        assert "C Code:" in prompt
        assert "Opportunity Tokens:" in prompt
        assert "[]" in prompt
        assert "<assembly>" in prompt

    def test_analyzer_credit_empty_optimizer_asm(self):
        """An empty optimizer_asm should yield 0.0 credit."""
        toks = [OpportunityToken("use_ldp_stp")]
        assert analyzer_credit(toks, "") == 0.0

    def test_analyzer_credit_empty_tokens(self):
        """An empty tokens list should yield 0.0."""
        assert analyzer_credit([], "ldp x0, x1, [sp]") == 0.0

    def test_analyzer_credit_tokens_with_no_mnemonics(self):
        """Tokens whose kind has an empty mnemonic tuple are skipped (counted=0)."""
        toks = [OpportunityToken("unroll_loop"), OpportunityToken("hoist_constant")]
        assert analyzer_credit(toks, "ldp x0, x1, [sp]") == 0.0

    def test_per_role_reward_zero_optimizer_reward(self):
        """A zero optimizer_reward should give zero to both roles."""
        toks = [OpportunityToken("use_ldp_stp")]
        asm = "ldp x0, x1, [sp]\nret"
        split = per_role_reward(0.0, toks, asm)
        assert split["optimizer"] == 0.0
        assert split["analyzer"] == 0.0

    def test_per_role_reward_negative_optimizer_reward(self):
        """A negative optimizer_reward should propagate correctly."""
        toks = [OpportunityToken("use_ldp_stp")]
        asm = "ldp x0, x1, [sp]\nret"
        split = per_role_reward(-1.0, toks, asm)
        assert split["optimizer"] == -1.0
        assert split["analyzer"] == pytest.approx(-0.3)
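    # -1.0 -> -0.3 suggests the analyzer's share here is 0.3 x the optimizer
    # reward at full token credit; the 0.3 weighting is inferred from this
    # assert, not from a documented constant in multi_agent.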

    def test_opportunity_token_to_json(self):
        """OpportunityToken.to_json should return a dict with all fields."""
        tok = OpportunityToken(
            kind="vectorize_neon", target_region="inner loop", rationale="wide"
        )
        d = tok.to_json()
        assert d == {
            "kind": "vectorize_neon",
            "target_region": "inner loop",
            "rationale": "wide",
        }

    def test_opportunity_token_to_json_none_fields(self):
        """Optional fields should be None in the output dict."""
        tok = OpportunityToken(kind="unroll_loop")
        d = tok.to_json()
        assert d["kind"] == "unroll_loop"
        assert d["target_region"] is None
        assert d["rationale"] is None


class TestEnvEdgeCases:
    """Curriculum transitions and ARMGymEnv reset/step/metadata/state."""

    def test_curriculum_update_advancing_stage(self):
        """The curriculum should advance when pass_ema exceeds the threshold."""
        from arm_gym.env import Curriculum

        cur = Curriculum(stage=1, max_stage=4, pass_ema=0.0, threshold=0.8, alpha=1.0)
        cur.update(True)
        assert cur.stage == 2
        assert cur.pass_ema == 0.0

    def test_curriculum_no_advance_past_max(self):
        """The curriculum should not advance past max_stage."""
        from arm_gym.env import Curriculum

        cur = Curriculum(stage=4, max_stage=4, pass_ema=0.79, threshold=0.8, alpha=1.0)
        cur.update(True)
        assert cur.stage == 4

    def test_curriculum_regress(self):
        """The curriculum should regress when pass_ema drops below 0.2."""
        from arm_gym.env import Curriculum

        cur = Curriculum(stage=3, pass_ema=0.1, threshold=0.8)
        cur.regress()
        assert cur.stage == 2
        assert cur.pass_ema == pytest.approx(0.4)
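        # 0.4 is threshold / 2 with threshold=0.8; regress() plausibly
        # re-seeds pass_ema at half the advance threshold, though that
        # reading is inferred from this assert alone.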

    def test_curriculum_no_regress_from_stage_1(self):
        """The curriculum should not regress below stage 1."""
        from arm_gym.env import Curriculum

        cur = Curriculum(stage=1, pass_ema=0.0)
        cur.regress()
        assert cur.stage == 1

    def test_curriculum_no_regress_when_ema_sufficient(self):
        """The curriculum should not regress when pass_ema >= 0.2."""
        from arm_gym.env import Curriculum

        cur = Curriculum(stage=3, pass_ema=0.3)
        cur.regress()
        assert cur.stage == 3

    def test_curriculum_filter_stage_0_warmstart_only(self):
        """Filtering at stage 0 returns the warm-start whitelist only.

        Warm-start whitelist = {(vec_add, n=16), (saxpy, n=32), (relu, n=64)}.
        At stage 0 only those variants may be eligible; higher-difficulty
        kernels must never leak past the warm-start gate.
        """
        from arm_gym.env import Curriculum

        cur = Curriculum(stage=0)
        all_v = generate_all()
        filtered = cur.filter(all_v)
        assert len(filtered) > 0
        whitelist = {name for name, _ in cur.WARMSTART_VARIANTS}
        for v in filtered:
            assert v.template_name in whitelist
            params = dict(v.params)
            target_n = next(n for name, n in cur.WARMSTART_VARIANTS if name == v.template_name)
            assert params.get("n") == target_n

    def test_curriculum_filter_stage_1(self):
        """Stage 1 should only include difficulty-1 variants."""
        from arm_gym.env import Curriculum

        cur = Curriculum(stage=1)
        all_v = generate_all()
        filtered = cur.filter(all_v)
        for v in filtered:
            assert TEMPLATES[v.template_name].difficulty <= 1

    def test_env_reset_returns_valid_observation(self):
        """reset() should return a CompilerObservation with the required fields."""
        from arm_gym.env import ARMGymEnv

        with patch("arm_gym.env.detect_toolchain") as mock_tc, \
                patch("arm_gym.env.compile_to_asm", return_value=".text\nret"), \
                patch("arm_gym.mca.run_mca") as mock_mca:
            mock_tc.return_value = MagicMock(
                clang="clang", gcc_aarch64="aarch64-linux-gnu-gcc",
                mca="llvm-mca", mcpu="neoverse-v2", mcpu_disclosed=None,
                ready=lambda: True,
            )
            mock_mca.return_value = McaReport(
                total_cycles=100, instructions=50, ipc=2.0,
                dispatch_stalls=0, resource_pressure_p99=0.3, raw="...",
            )
            env = ARMGymEnv(variants=generate_all())
            obs = env.reset(seed=42)
            assert obs.done is False
            assert obs.variant_id
            assert obs.c_source
            assert obs.baseline_asm
            assert obs.baseline_cycles > 0
            assert obs.difficulty >= 1

    def test_env_step_with_unknown_variant_id(self):
        """step() with an unknown variant_id should return an error observation."""
        from arm_gym.env import ARMGymEnv, CompilerAction

        with patch("arm_gym.env.detect_toolchain") as mock_tc:
            mock_tc.return_value = MagicMock(
                clang="clang", gcc_aarch64="aarch64-linux-gnu-gcc",
                mca="llvm-mca", mcpu="neoverse-v2", mcpu_disclosed=None,
                ready=lambda: True,
            )
            env = ARMGymEnv(variants=generate_all())
            action = CompilerAction(variant_id="nonexistent_abc123", assembly="ret")
            obs = env.step(action)
            assert obs.done is True
            assert obs.reward == 0.0
            assert obs.error_json is not None
            assert "unknown_variant" in obs.error_json

    def test_env_metadata_returns_required_fields(self):
        """metadata() should contain all expected keys."""
        from arm_gym.env import ARMGymEnv

        with patch("arm_gym.env.detect_toolchain") as mock_tc:
            mock_tc.return_value = MagicMock(
                clang="clang", gcc_aarch64="aarch64-linux-gnu-gcc",
                mca="llvm-mca", mcpu="neoverse-v2", mcpu_disclosed=None,
                ready=lambda: True,
            )
            env = ARMGymEnv(variants=generate_all())
            meta = env.metadata()
            required_keys = {
                "name", "version", "description", "supports_concurrent_sessions",
                "templates", "variants", "curriculum_stages", "reward_mode", "mcpu",
            }
            assert required_keys.issubset(meta.keys())
            assert meta["name"] == "arm-gym"

    def test_env_state_before_any_reset(self):
        """state before reset should have None for episode-specific fields."""
        from arm_gym.env import ARMGymEnv

        with patch("arm_gym.env.detect_toolchain") as mock_tc:
            mock_tc.return_value = MagicMock(
                clang="clang", gcc_aarch64="aarch64-linux-gnu-gcc",
                mca="llvm-mca", mcpu="neoverse-v2", mcpu_disclosed=None,
                ready=lambda: True,
            )
            env = ARMGymEnv(variants=generate_all())
            s = env.state
            assert s["episode_id"] is None
            assert s["variant_id"] is None
            assert s["kernel_name"] is None
            assert s["step_count"] == 0


class TestVerifierEdgeCases:
    """sigma_sanity_ok boundaries and baseline-distribution loading."""

    def test_sigma_sanity_ok_exactly_at_3sigma(self):
        """Exactly at mean + 3*std should pass (<=)."""
        dist = {"mean": 1.0, "std": 0.1}
        speedup = 1.0 + 3.0 * 0.1
        assert sigma_sanity_ok(speedup, dist) is True

    def test_sigma_sanity_ok_just_above_3sigma(self):
        """Just above mean + 3*std should fail."""
        dist = {"mean": 1.0, "std": 0.1}
        speedup = 1.0 + 3.0 * 0.1 + 0.001
        assert sigma_sanity_ok(speedup, dist) is False

    def test_sigma_sanity_ok_just_below_3sigma(self):
        """Just below mean + 3*std should pass."""
        dist = {"mean": 1.0, "std": 0.1}
        speedup = 1.0 + 3.0 * 0.1 - 0.001
        assert sigma_sanity_ok(speedup, dist) is True

    def test_sigma_sanity_ok_zero_std(self):
        """Zero std triggers the max(std, 1e-6) guard."""
        dist = {"mean": 1.0, "std": 0.0}
        assert sigma_sanity_ok(1.000002, dist) is True
        assert sigma_sanity_ok(1.1, dist) is False

    def test_sigma_sanity_ok_missing_keys_uses_defaults(self):
        """Missing mean/std should fall back to the defaults (1.0, 0.1)."""
        dist = {}
        assert sigma_sanity_ok(1.3, dist) is True
        assert sigma_sanity_ok(1.4, dist) is False
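    # Arithmetic behind the two tests above: with std=0 the guard gives a
    # cutoff of mean + 3 * 1e-6 = 1.000003, so 1.000002 passes and 1.1 fails;
    # with the (1.0, 0.1) defaults the cutoff is 1.0 + 3 * 0.1 = 1.3, so 1.3
    # passes (inclusive) and 1.4 fails.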

    def test_load_baseline_distribution_missing_file(self):
        """A missing file should return None."""
        result = load_baseline_distribution("nonexistent_variant_xyz")
        assert result is None

    def test_load_baseline_distribution_corrupt_json(self):
        """A corrupt JSON file should return None (the exception is caught)."""
        with tempfile.TemporaryDirectory() as d:
            p = Path(d) / "corrupt.json"
            p.write_text("{broken json///")
            with patch("arm_gym.verifier.BASELINE_DIST_DIR", Path(d)):
                result = load_baseline_distribution("corrupt")
                assert result is None

    def test_load_baseline_distribution_valid_json(self):
        """A valid JSON file should return the parsed dict."""
        with tempfile.TemporaryDirectory() as d:
            data = {"mean": 1.2, "std": 0.15, "n": 100}
            p = Path(d) / "good_variant.json"
            p.write_text(json.dumps(data))
            with patch("arm_gym.verifier.BASELINE_DIST_DIR", Path(d)):
                result = load_baseline_distribution("good_variant")
                assert result == data


class TestTrainEdgeCases:
    """TrainConfig defaults, stack detection, and smoke-mode entry points."""

    def test_train_config_lora_alpha_is_rank_times_2(self):
        """The lora_alpha property should be lora_rank * 2."""
        from arm_gym.train import TrainConfig

        cfg = TrainConfig(lora_rank=16)
        assert cfg.lora_alpha == 32

        cfg2 = TrainConfig(lora_rank=32)
        assert cfg2.lora_alpha == 64

        cfg3 = TrainConfig(lora_rank=1)
        assert cfg3.lora_alpha == 2

    def test_detect_stack_without_torch(self):
        """Without torch/unsloth/vllm, detect_stack should return 'single_gpu'."""
        from arm_gym.train import detect_stack

        with patch("importlib.util.find_spec", return_value=None):
            with patch.dict(os.environ, {}, clear=True):
                assert detect_stack() == "single_gpu"

    def test_detect_stack_with_unsloth_and_vllm(self):
        """With both unsloth and vllm available, detect_stack should return 'unsloth_vllm'."""
        from arm_gym.train import detect_stack

        def fake_find_spec(name):
            if name in ("unsloth", "vllm"):
                return MagicMock()
            return None

        with patch("importlib.util.find_spec", side_effect=fake_find_spec):
            assert detect_stack() == "unsloth_vllm"

    def test_detect_stack_plain_trl_ddp(self):
        """With torch plus a WORLD_SIZE env var, detect_stack returns 'plain_trl_ddp'."""
        from arm_gym.train import detect_stack

        def fake_find_spec(name):
            if name == "torch":
                return MagicMock()
            return None

        with patch("importlib.util.find_spec", side_effect=fake_find_spec):
            with patch.dict(os.environ, {"WORLD_SIZE": "4"}):
                assert detect_stack() == "plain_trl_ddp"
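    # Decision table implied by the three tests above: unsloth + vllm ->
    # "unsloth_vllm"; torch with WORLD_SIZE set -> "plain_trl_ddp"; nothing
    # importable -> "single_gpu". (A summary of the tests, not of train.py.)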

    def test_main_smoke_returns_zero(self):
        """main(["--smoke"]) should return 0 without needing a GPU."""
        from arm_gym.train import main

        assert main(["--smoke"]) == 0

    def test_main_smoke_skip_sft(self):
        """main(["--smoke", "--skip-sft"]) should also return 0."""
        from arm_gym.train import main

        assert main(["--smoke", "--skip-sft"]) == 0

    def test_smoke_day0_correctness_default(self):
        """The default smoke_day0_correctness returns 0.6 (above the SFT threshold)."""
        from arm_gym.train import smoke_day0_correctness

        rate = smoke_day0_correctness(None, "Qwen/Qwen2.5-Coder-7B-Instruct")
        assert rate == 0.6

    def test_smoke_day0_correctness_forced_sft(self):
        """FORCE_SFT_WARMUP=1 makes smoke_day0_correctness return 0.0."""
        from arm_gym.train import smoke_day0_correctness

        with patch.dict(os.environ, {"FORCE_SFT_WARMUP": "1"}):
            rate = smoke_day0_correctness(None, "model-id")
            assert rate == 0.0

    def test_maybe_sft_warmup_skips_when_above_threshold(self):
        """maybe_sft_warmup should return False when correctness >= 0.4."""
        from arm_gym.train import TrainConfig, maybe_sft_warmup

        cfg = TrainConfig(skip_sft=False)
        assert maybe_sft_warmup(cfg, None) is False

    def test_maybe_sft_warmup_skips_when_skip_sft_flag(self):
        """maybe_sft_warmup should return False when skip_sft=True, even when forced."""
        from arm_gym.train import TrainConfig, maybe_sft_warmup

        cfg = TrainConfig(skip_sft=True)
        with patch.dict(os.environ, {"FORCE_SFT_WARMUP": "1"}):
            assert maybe_sft_warmup(cfg, None) is False

    def test_train_config_defaults(self):
        """Verify that the scaffold defaults match the CLAUDE.md documentation."""
        from arm_gym.train import TrainConfig

        cfg = TrainConfig()
        assert cfg.lora_rank == 16
        assert cfg.lora_alpha == 32
        assert cfg.num_generations == 4
        assert cfg.per_device_train_batch_size == 1
        assert cfg.gradient_accumulation_steps == 32


class TestCompileBaselineEdgeCases:
    """Toolchain discovery and baseline-compilation edge cases."""

    def test_find_tool_returns_none_when_no_candidate_found(self):
        """find_tool with unavailable candidates should return None."""
        from arm_gym.compile_baseline import find_tool

        with patch("shutil.which", return_value=None):
            assert find_tool(["nonexistent-1", "nonexistent-2"]) is None

    def test_find_tool_returns_first_available(self):
        """find_tool should return the first candidate that shutil.which resolves."""
        from arm_gym.compile_baseline import find_tool

        def fake_which(name):
            if name == "clang-21":
                return None
            if name == "clang-20":
                return "/usr/bin/clang-20"
            return None

        with patch("shutil.which", side_effect=fake_which):
            assert find_tool(["clang-21", "clang-20", "clang"]) == "clang-20"

    def test_toolchain_info_ready_with_neither(self):
        """ToolchainInfo.ready() is False when both compilers are None."""
        from arm_gym.compile_baseline import ToolchainInfo

        tc = ToolchainInfo(clang=None, gcc_aarch64=None, mca=None,
                           mcpu="neoverse-v2", mcpu_disclosed=None)
        assert tc.ready() is False

    def test_toolchain_info_ready_with_clang_only(self):
        """ToolchainInfo.ready() is True with only clang present."""
        from arm_gym.compile_baseline import ToolchainInfo

        tc = ToolchainInfo(clang="/usr/bin/clang", gcc_aarch64=None, mca=None,
                           mcpu="neoverse-v2", mcpu_disclosed=None)
        assert tc.ready() is True

    def test_compile_to_asm_no_compiler_raises(self):
        """compile_to_asm with no compilers should raise RuntimeError."""
        from arm_gym.compile_baseline import ToolchainInfo, compile_to_asm

        tc = ToolchainInfo(clang=None, gcc_aarch64=None, mca=None,
                           mcpu="neoverse-v2", mcpu_disclosed=None)
        with pytest.raises(RuntimeError, match="no AArch64 compiler available"):
            compile_to_asm("void kernel(void){}", tc)

    def test_detect_toolchain_without_any_tools(self):
        """detect_toolchain with no tools installed should still return a ToolchainInfo."""
        from arm_gym.compile_baseline import detect_toolchain

        with patch("shutil.which", return_value=None), \
                patch("arm_gym.compile_baseline.find_tool", return_value=None):
            tc = detect_toolchain()
            assert tc.mcpu
            assert tc.clang is None
            assert tc.gcc_aarch64 is None