File size: 2,154 Bytes
f28409b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | """Tests for the curated + procedural scenario sampler."""
from __future__ import annotations
import pytest
from server.tasks.scenarios import (
CURATED_SCENARIOS,
Scenario,
sample_scenario,
)
def test_curated_scenarios_exist_and_are_unique():
    """Curated scenarios have distinct names and include the two anchor cases."""
    scenario_names = [scenario.name for scenario in CURATED_SCENARIOS]
    # Any duplicate name collapses in the set, making it shorter than the list.
    assert len(set(scenario_names)) == len(scenario_names)
    for required in ("easy_diphoton_160", "higgs_like_125"):
        assert required in scenario_names
def test_sample_scenario_by_name_returns_correct_scenario():
    """Requesting a curated scenario by name returns that exact scenario."""
    scenario = sample_scenario(name="higgs_like_125", seed=1)
    assert isinstance(scenario, Scenario)
    assert scenario.name == "higgs_like_125"
    # pytest.approx avoids brittle exact float equality on the mass.
    expected_mass_gev = 125.0
    assert scenario.latent.particle.mass_gev == pytest.approx(expected_mass_gev)
def test_sample_scenario_seed_is_reproducible():
    """Two draws with the same difficulty and seed yield the same particle mass."""
    first, second = [
        sample_scenario(difficulty="medium", seed=42) for _ in range(2)
    ]
    # Different seeds may legitimately select different scenarios; with an
    # identical seed the draw must be deterministic, at least in mass.
    assert first.latent.particle.mass_gev == pytest.approx(
        second.latent.particle.mass_gev
    )
@pytest.mark.parametrize("difficulty", ["easy", "medium", "hard"])
def test_difficulty_tier_bounds_respected(difficulty):
    """Sampling a tier over many seeds yields masses inside that tier's range.

    This is deliberately a weak, smoke-level check. It only requires that at
    least one of the sampled masses lands inside the tier's procedural
    bounds. The sampler may also return curated scenarios, whose masses may
    fall outside the procedural bounds. A strict "every sample is in range"
    assertion is therefore not made here.
    """
    # Procedural mass bounds (inclusive, in GeV) for each difficulty tier.
    low, high = {
        "easy": (90.0, 250.0),
        "medium": (100.0, 600.0),
        "hard": (250.0, 1500.0),
    }[difficulty]
    seen_masses = [
        sample_scenario(difficulty=difficulty, seed=seed).latent.particle.mass_gev
        for seed in range(50)
    ]
    # Report every observed mass on failure so a regression is diagnosable.
    assert any(low <= mass <= high for mass in seen_masses), (
        f"no {difficulty!r} sample fell within [{low}, {high}] GeV; "
        f"saw masses {seen_masses}"
    )
def test_fresh_latent_is_independent_copy():
    """Latents returned by ``fresh_latent`` do not share mutable resource state."""
    scenario = sample_scenario(name="easy_diphoton_160", seed=1)
    mutated, untouched = scenario.fresh_latent(), scenario.fresh_latent()
    # Mutating one copy's resources must not leak into a sibling copy.
    mutated.resources.budget_used_musd = 99.0
    assert untouched.resources.budget_used_musd == 0.0
|