Spaces:
Paused
Paused
| """Tests for the adaptive curriculum manager + prompt-schedule helper.""" | |
| from __future__ import annotations | |
| import pytest | |
| from training.curriculum import CurriculumConfig, CurriculumManager, TIERS | |
| from training.training_script import ( | |
| DEFAULT_CURRICULUM_SCHEDULE, | |
| curriculum_difficulty_for, | |
| ) | |
def test_schedule_ramps_easy_to_hard():
    """The default schedule opens on "easy", closes on "hard", and visits every tier."""
    episodes = 100
    tiers = [curriculum_difficulty_for(ep, episodes) for ep in range(episodes)]
    first, *_, last = tiers
    assert first == "easy"
    assert last == "hard"
    # Every tier must show up at least once somewhere along the ramp.
    assert {"easy", "medium", "hard"} <= set(tiers)
def test_schedule_respects_custom_phases():
    """A custom schedule yields contiguous phases sized by their fractions.

    Checking per-tier counts alone would accept any interleaving of tiers.
    The schedule is a sequence of phases, so each tier must also occupy its
    own contiguous slice of episodes, in the order the schedule lists them.
    """
    n = 100
    schedule = [("easy", 0.2), ("medium", 0.6), ("hard", 0.2)]
    diffs = [curriculum_difficulty_for(i, n, schedule) for i in range(n)]
    # Per-tier counts give a readable failure message if a phase is mis-sized.
    assert diffs.count("easy") == 20
    assert diffs.count("medium") == 60
    assert diffs.count("hard") == 20
    # Phases are contiguous and appear in schedule order.
    assert diffs == ["easy"] * 20 + ["medium"] * 60 + ["hard"] * 20
def test_curriculum_promotes_after_streak_of_successes():
    """A full window of successes at the starting tier triggers exactly one promotion."""
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        promote_threshold=0.6,
    )
    manager = CurriculumManager(config)
    assert manager.next_difficulty() == "easy"
    for outcome in [True] * 10:
        manager.record(success=outcome)
    assert manager.next_difficulty() == "medium"
    assert manager.state.promotions == 1
def test_curriculum_demotes_after_streak_of_failures():
    """A full window of failures starting at "medium" drops the curriculum to "easy"."""
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        start_difficulty="medium",
        demote_threshold=0.2,
    )
    manager = CurriculumManager(config)
    for outcome in [False] * 10:
        manager.record(success=outcome)
    assert manager.next_difficulty() == "easy"
    assert manager.state.demotions == 1
def test_curriculum_does_not_change_before_min_rollouts():
    """The tier stays put until ``min_rollouts_per_tier`` outcomes are recorded.

    Five straight successes would otherwise look like a perfect streak, so this
    exercises the ``min_rollouts_per_tier`` gate (assuming the default promote
    threshold is reachable by an all-success window).
    """
    cm = CurriculumManager(CurriculumConfig(window=20, min_rollouts_per_tier=20))
    for _ in range(5):
        cm.record(success=True)
    assert cm.next_difficulty() == "easy"
    # "Does not change" should also mean no tier transition was counted, which
    # rules out e.g. a promote-then-demote round trip landing back on "easy".
    assert cm.state.promotions == 0
    assert cm.state.demotions == 0
def test_curriculum_tracks_rolling_success():
    """rolling_success() reports the fraction of successes among recorded outcomes."""
    # A large min_rollouts_per_tier keeps the tier from moving while we record.
    manager = CurriculumManager(CurriculumConfig(window=10, min_rollouts_per_tier=100))
    outcomes = [True] * 7 + [False] * 3
    for outcome in outcomes:
        manager.record(success=outcome)
    assert manager.rolling_success() == pytest.approx(0.7)
def test_curriculum_serialises_to_dict():
    """to_dict() exposes both config and live state, with a valid current tier."""
    manager = CurriculumManager(CurriculumConfig())
    manager.record(success=True)
    snapshot = manager.to_dict()
    for section in ("config", "state"):
        assert section in snapshot
    assert snapshot["state"]["current"] in TIERS