Spaces:
Paused
Paused
File size: 2,524 Bytes
0a6c641 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | """Tests for the adaptive curriculum manager + prompt-schedule helper."""
from __future__ import annotations
import pytest
from training.curriculum import CurriculumConfig, CurriculumManager, TIERS
from training.training_script import (
DEFAULT_CURRICULUM_SCHEDULE,
curriculum_difficulty_for,
)
def test_schedule_ramps_easy_to_hard():
    """The default schedule opens on "easy", closes on "hard", and visits every tier."""
    total = 100
    tiers_seen = [curriculum_difficulty_for(step, total) for step in range(total)]
    first, *_, last = tiers_seen
    # Endpoints of the ramp.
    assert first == "easy"
    assert last == "hard"
    # No tier is skipped over entirely.
    assert {"easy", "medium", "hard"}.issubset(tiers_seen)
def test_schedule_respects_custom_phases():
    """A caller-supplied schedule splits episodes in proportion to each phase's share."""
    total = 100
    phases = [("easy", 0.2), ("medium", 0.6), ("hard", 0.2)]
    assigned = [curriculum_difficulty_for(step, total, phases) for step in range(total)]
    # 20% / 60% / 20% of 100 episodes.
    expected_counts = {"easy": 20, "medium": 60, "hard": 20}
    for tier, wanted in expected_counts.items():
        assert assigned.count(tier) == wanted
def test_curriculum_promotes_after_streak_of_successes():
    """A full window of successes on "easy" moves the learner up to "medium" once."""
    streak = 10
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        promote_threshold=0.6,
    )
    manager = CurriculumManager(config)
    assert manager.next_difficulty() == "easy"

    for _ in range(streak):
        manager.record(success=True)

    assert manager.next_difficulty() == "medium"
    assert manager.state.promotions == 1
def test_curriculum_demotes_after_streak_of_failures():
    """A full window of failures starting on "medium" drops the learner to "easy" once."""
    streak = 10
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        start_difficulty="medium",
        demote_threshold=0.2,
    )
    manager = CurriculumManager(config)

    for _ in range(streak):
        manager.record(success=False)

    assert manager.next_difficulty() == "easy"
    assert manager.state.demotions == 1
def test_curriculum_does_not_change_before_min_rollouts():
    """Stay on the starting tier until ``min_rollouts_per_tier`` outcomes are recorded.

    Every recorded outcome here is a success. However, only 5 are recorded, fewer than
    the 20 rollouts the config requires, so no promotion may happen yet.
    """
    cm = CurriculumManager(CurriculumConfig(window=20, min_rollouts_per_tier=20))
    for _ in range(5):
        cm.record(success=True)
    assert cm.next_difficulty() == "easy"
    # Also check the counter, so a transient (promote-then-revert) promotion
    # cannot slip through just because the final tier happens to be "easy".
    assert cm.state.promotions == 0
def test_curriculum_tracks_rolling_success():
    """Rolling success equals the fraction of successes over the last ``window`` outcomes."""
    # A huge min_rollouts_per_tier keeps the tier fixed while outcomes accumulate.
    manager = CurriculumManager(CurriculumConfig(window=10, min_rollouts_per_tier=100))
    outcomes = [True] * 7 + [False] * 3
    for outcome in outcomes:
        manager.record(success=outcome)
    assert manager.rolling_success() == pytest.approx(0.7)
def test_curriculum_serialises_to_dict():
    """``to_dict`` yields a snapshot with config and state, whose current tier is known."""
    manager = CurriculumManager(CurriculumConfig())
    manager.record(success=True)
    payload = manager.to_dict()
    assert "config" in payload
    assert "state" in payload
    assert payload["state"]["current"] in TIERS
|