File size: 2,524 Bytes
f28409b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Tests for the adaptive curriculum manager + prompt-schedule helper."""

from __future__ import annotations

import pytest

from training.curriculum import CurriculumConfig, CurriculumManager, TIERS
from training.training_script import (
    DEFAULT_CURRICULUM_SCHEDULE,
    curriculum_difficulty_for,
)


def test_schedule_ramps_easy_to_hard():
    """Check the schedule used when no explicit schedule argument is passed.

    Over 100 episodes the first must map to "easy", the last to "hard",
    and every one of the three tier names must show up somewhere.
    """
    total = 100
    tiers_seen = [curriculum_difficulty_for(step, total) for step in range(total)]

    # Endpoints: opening episode is easy, closing episode is hard.
    first, *_, last = tiers_seen
    assert first == "easy"
    assert last == "hard"

    # Coverage: no tier is skipped entirely.
    for tier in ("easy", "medium", "hard"):
        assert tier in tiers_seen


def test_schedule_respects_custom_phases():
    """Check per-tier episode counts when a custom schedule is supplied.

    The schedule pairs each tier name with a number (0.2 / 0.6 / 0.2); for
    100 episodes the expected counts are 20 / 60 / 20 respectively.
    """
    total = 100
    phases = [("easy", 0.2), ("medium", 0.6), ("hard", 0.2)]
    assigned = [
        curriculum_difficulty_for(step, total, phases) for step in range(total)
    ]

    expected_counts = {"easy": 20, "medium": 60, "hard": 20}
    for tier, expected in expected_counts.items():
        assert assigned.count(tier) == expected


def test_curriculum_promotes_after_streak_of_successes():
    """Ten recorded successes should move the manager from "easy" to "medium".

    Also checks that exactly one promotion is counted in the manager state.
    """
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        promote_threshold=0.6,
    )
    manager = CurriculumManager(config)

    # Baseline before any outcome has been recorded.
    assert manager.next_difficulty() == "easy"

    for _ in range(10):
        manager.record(success=True)

    assert manager.next_difficulty() == "medium"
    assert manager.state.promotions == 1


def test_curriculum_demotes_after_streak_of_failures():
    """Ten recorded failures should move a "medium" start down to "easy".

    Also checks that exactly one demotion is counted in the manager state.
    """
    config = CurriculumConfig(
        window=10,
        min_rollouts_per_tier=10,
        start_difficulty="medium",
        demote_threshold=0.2,
    )
    manager = CurriculumManager(config)

    for _ in range(10):
        manager.record(success=False)

    assert manager.next_difficulty() == "easy"
    assert manager.state.demotions == 1


def test_curriculum_does_not_change_before_min_rollouts():
    """Five successes with min_rollouts_per_tier=20 must leave the tier at "easy"."""
    config = CurriculumConfig(window=20, min_rollouts_per_tier=20)
    manager = CurriculumManager(config)

    # Record fewer outcomes than the configured minimum of 20.
    for _ in range(5):
        manager.record(success=True)

    assert manager.next_difficulty() == "easy"


def test_curriculum_tracks_rolling_success():
    """Seven successes then three failures should give a rolling success of ~0.7."""
    config = CurriculumConfig(window=10, min_rollouts_per_tier=100)
    manager = CurriculumManager(config)

    # Same recording order as before: all successes first, then the failures.
    outcomes = [True] * 7 + [False] * 3
    for outcome in outcomes:
        manager.record(success=outcome)

    assert manager.rolling_success() == pytest.approx(0.7)


def test_curriculum_serialises_to_dict():
    """to_dict() must expose "config" and "state", with state["current"] in TIERS."""
    manager = CurriculumManager(CurriculumConfig())
    manager.record(success=True)

    snapshot = manager.to_dict()

    for key in ("config", "state"):
        assert key in snapshot
    assert snapshot["state"]["current"] in TIERS