"""Tests for the adaptive curriculum manager + prompt-schedule helper."""

from __future__ import annotations

import pytest

from training.curriculum import CurriculumConfig, CurriculumManager, TIERS
from training.training_script import (
    DEFAULT_CURRICULUM_SCHEDULE,
    curriculum_difficulty_for,
)


# ---------------------------------------------------------------------------
# curriculum_difficulty_for: position-based prompt schedule
# ---------------------------------------------------------------------------


def test_schedule_ramps_easy_to_hard():
    """Without an explicit schedule, episodes ramp from easy to hard.

    The first episode must be easy, the last must be hard, and every tier
    must be visited at least once along the way.
    """
    n = 100
    diffs = [curriculum_difficulty_for(i, n) for i in range(n)]
    # First episode is easy, last is hard.
    assert diffs[0] == "easy"
    assert diffs[-1] == "hard"
    # Each tier appears at least once.
    assert {"easy", "medium", "hard"} <= set(diffs)


def test_schedule_respects_custom_phases():
    """A custom (tier, fraction) schedule yields the matching episode counts."""
    n = 100
    schedule = [("easy", 0.2), ("medium", 0.6), ("hard", 0.2)]
    diffs = [curriculum_difficulty_for(i, n, schedule) for i in range(n)]
    # With n = 100 each fraction maps to a whole number of episodes.
    assert diffs.count("easy") == 20
    assert diffs.count("medium") == 60
    assert diffs.count("hard") == 20


# ---------------------------------------------------------------------------
# CurriculumManager: adaptive promotion / demotion
# ---------------------------------------------------------------------------


def test_curriculum_promotes_after_streak_of_successes():
    """A full window of successes on easy promotes once to medium."""
    cm = CurriculumManager(
        CurriculumConfig(window=10, min_rollouts_per_tier=10, promote_threshold=0.6)
    )
    assert cm.next_difficulty() == "easy"
    for _ in range(10):
        cm.record(success=True)
    assert cm.next_difficulty() == "medium"
    assert cm.state.promotions == 1


def test_curriculum_demotes_after_streak_of_failures():
    """A full window of failures on medium demotes once to easy."""
    cm = CurriculumManager(
        CurriculumConfig(
            window=10,
            min_rollouts_per_tier=10,
            start_difficulty="medium",
            demote_threshold=0.2,
        )
    )
    for _ in range(10):
        cm.record(success=False)
    assert cm.next_difficulty() == "easy"
    assert cm.state.demotions == 1


def test_curriculum_does_not_change_before_min_rollouts():
    """Fewer outcomes than ``min_rollouts_per_tier`` never change the tier."""
    cm = CurriculumManager(CurriculumConfig(window=20, min_rollouts_per_tier=20))
    for _ in range(5):
        cm.record(success=True)
    assert cm.next_difficulty() == "easy"
    assert cm.state.promotions == 0


def test_curriculum_tracks_rolling_success():
    """``rolling_success`` is the success fraction over the recorded window."""
    # min_rollouts_per_tier is far above the 10 outcomes recorded here, so the
    # manager should stay on its starting tier for the whole test.
    cm = CurriculumManager(CurriculumConfig(window=10, min_rollouts_per_tier=100))
    for _ in range(7):
        cm.record(success=True)
    for _ in range(3):
        cm.record(success=False)
    assert cm.rolling_success() == pytest.approx(0.7)


def test_curriculum_serialises_to_dict():
    """``to_dict`` exposes config and state, with a valid current tier."""
    cm = CurriculumManager(CurriculumConfig())
    cm.record(success=True)
    snap = cm.to_dict()
    assert "config" in snap and "state" in snap
    assert snap["state"]["current"] in TIERS