| """Tests for the adaptive curriculum manager + prompt-schedule helper."""
|
|
|
| from __future__ import annotations
|
|
|
| import pytest
|
|
|
| from training.curriculum import CurriculumConfig, CurriculumManager, TIERS
|
| from training.training_script import (
|
| DEFAULT_CURRICULUM_SCHEDULE,
|
| curriculum_difficulty_for,
|
| )
|
|
|
|
|
def test_schedule_ramps_easy_to_hard():
    """The default schedule starts at "easy", ends at "hard", and uses all three tiers."""
    total = 100
    tiers_by_step = [curriculum_difficulty_for(step, total) for step in range(total)]

    # Endpoints of the ramp.
    first, *_, last = tiers_by_step
    assert first == "easy"
    assert last == "hard"

    # Every tier must appear somewhere along the way.
    for tier in ("easy", "medium", "hard"):
        assert tier in tiers_by_step
|
|
|
|
|
def test_schedule_respects_custom_phases():
    """A custom 20/60/20 schedule over 100 steps yields exactly 20/60/20 prompts per tier."""
    total = 100
    phases = [("easy", 0.2), ("medium", 0.6), ("hard", 0.2)]
    tiers_by_step = [
        curriculum_difficulty_for(step, total, phases) for step in range(total)
    ]

    # Expected number of steps assigned to each tier.
    expected_counts = (("easy", 20), ("medium", 60), ("hard", 20))
    for tier, expected in expected_counts:
        assert tiers_by_step.count(tier) == expected
|
|
|
|
|
def test_curriculum_promotes_after_streak_of_successes():
    """Ten straight successes move the manager from "easy" to "medium" (one promotion)."""
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        promote_threshold=0.6,
    )
    manager = CurriculumManager(config)

    # Before any rollout is recorded the manager serves the "easy" tier.
    assert manager.next_difficulty() == "easy"

    rollouts = 10
    for _ in range(rollouts):
        manager.record(success=True)

    assert manager.next_difficulty() == "medium"
    assert manager.state.promotions == 1
|
|
|
|
|
def test_curriculum_demotes_after_streak_of_failures():
    """Ten straight failures starting at "medium" drop the manager to "easy" (one demotion)."""
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        start_difficulty="medium",
        demote_threshold=0.2,
    )
    manager = CurriculumManager(config)

    rollouts = 10
    for _ in range(rollouts):
        manager.record(success=False)

    assert manager.next_difficulty() == "easy"
    assert manager.state.demotions == 1
|
|
|
|
|
def test_curriculum_does_not_change_before_min_rollouts():
    """Five successes, fewer than min_rollouts_per_tier=20, leave the tier at "easy"."""
    config = CurriculumConfig(window=20, min_rollouts_per_tier=20)
    manager = CurriculumManager(config)

    recorded = 5  # deliberately below the configured minimum of 20
    for _ in range(recorded):
        manager.record(success=True)

    assert manager.next_difficulty() == "easy"
|
|
|
|
|
def test_curriculum_tracks_rolling_success():
    """Seven successes followed by three failures give a rolling success of about 0.7."""
    config = CurriculumConfig(window=10, min_rollouts_per_tier=100)
    manager = CurriculumManager(config)

    # Same order as recording 7 successes and then 3 failures.
    outcomes = [True] * 7 + [False] * 3
    for outcome in outcomes:
        manager.record(success=outcome)

    assert manager.rolling_success() == pytest.approx(0.7)
|
|
|
|
|
def test_curriculum_serialises_to_dict():
    """to_dict() exposes "config" and "state" keys, and state["current"] is a known tier."""
    manager = CurriculumManager(CurriculumConfig())
    manager.record(success=True)

    snapshot = manager.to_dict()

    # Both top-level sections must be present in the snapshot.
    for section in ("config", "state"):
        assert section in snapshot

    current_tier = snapshot["state"]["current"]
    assert current_tier in TIERS
|
|
|