File size: 6,417 Bytes
f28409b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Adaptive curriculum for CERNenv training.



The hackathon participant guide (§6, §14) and the FAQs (Q14, Q35) all push

the same idea: don't start RL on the hardest distribution because if the

probability of a successful rollout is ≈ 0, the policy never sees positive

reward and learning stalls. Conversely, if the env is too easy, the

gradient signal collapses and training plateaus (RLVE, FAQ Q46).



This module exposes a tiny ``CurriculumManager`` that:



* tracks rolling success rate across rollouts,

* promotes the agent up the difficulty ladder when the policy is clearly

  succeeding (rolling success ≥ ``promote_threshold``),

* demotes the agent when the policy is repeatedly failing (rolling

  success ≤ ``demote_threshold``) so it can recover instead of burning

  compute on impossible tasks,

* exposes a ``next_difficulty()`` hook that the training loop calls

  whenever it resets the environment.



The state is intentionally minimal so it can be serialised into a JSON

checkpoint or logged into ``evidence/training_log.csv``.

"""

from __future__ import annotations

import json
from collections import deque
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Deque, Dict, List, Optional


# Difficulty tiers, in order of increasing challenge. ``None`` is treated
# as "let the sampler choose at random" and is not part of the ladder.
# ``CurriculumManager`` promotes/demotes by indexing into this list, so
# the ordering here is load-bearing.
TIERS: List[str] = ["easy", "medium", "hard"]


@dataclass
class CurriculumConfig:
    """Knobs for the adaptive curriculum.

    The thresholds apply to the rolling success rate in [0, 1] computed
    over the most recent ``window`` rollouts.
    """

    # Tier to begin training at; must be a member of ``TIERS``.
    start_difficulty: str = "easy"
    # Number of most-recent rollouts kept in the rolling-success window.
    window: int = 24
    # Rolling success >= this promotes the agent to a harder tier.
    promote_threshold: float = 0.55
    # Rolling success <= this demotes the agent to an easier tier.
    demote_threshold: float = 0.10
    # Minimum rollouts at the current tier before any promotion or
    # demotion can fire. Prevents flapping when the rolling window is
    # still short.
    min_rollouts_per_tier: int = 12
    # If True, advance one tier at a time; otherwise jump to the highest
    # tier the policy clears.
    one_step_at_a_time: bool = True


@dataclass
class CurriculumState:
    """Serialisable curriculum state for logging / resume."""

    # Tier currently in force; one of ``TIERS``.
    current: str = "easy"
    # One snapshot dict appended per recorded rollout. NOTE(review):
    # grows without bound over a long run — consider trimming if
    # checkpoints get large.
    history: List[Dict[str, float]] = field(default_factory=list)
    # Rollouts observed since the last tier change.
    rollouts_at_current: int = 0
    # Successful rollouts observed since the last tier change.
    successes_at_current: int = 0
    # Lifetime count of tier promotions.
    promotions: int = 0
    # Lifetime count of tier demotions.
    demotions: int = 0


class CurriculumManager:
    """Rolling-window adaptive curriculum over the ``TIERS`` ladder.

    Tracks per-rollout success in a bounded window and promotes/demotes
    the current difficulty when the rolling success rate crosses the
    thresholds in :class:`CurriculumConfig`.

    Usage::

        cm = CurriculumManager(CurriculumConfig())
        for ep in range(N):
            diff = cm.next_difficulty()
            obs = env.reset(seed=ep, difficulty=diff)
            ...                      # rollout
            cm.record(success=ep_success, reward=cumulative_reward)

    The manager is single-threaded; if you parallelise rollouts wrap
    ``record`` calls in a lock at the caller.
    """

    def __init__(self, config: Optional[CurriculumConfig] = None) -> None:
        """Initialise the manager.

        Args:
            config: Curriculum knobs; ``None`` (the default) builds a fresh
                ``CurriculumConfig()``. The previous signature used
                ``config=CurriculumConfig()`` directly — a mutable default
                evaluated once at ``def`` time and silently shared by every
                manager constructed without an explicit config.

        Raises:
            ValueError: If ``config.start_difficulty`` is not in ``TIERS``.
        """
        if config is None:
            config = CurriculumConfig()
        if config.start_difficulty not in TIERS:
            raise ValueError(
                f"start_difficulty must be one of {TIERS}, got {config.start_difficulty!r}"
            )
        self.config = config
        self.state = CurriculumState(current=config.start_difficulty)
        # 1.0 = success, 0.0 = failure. ``deque(maxlen=...)`` drops the
        # oldest outcome automatically once the window is full.
        self._window: Deque[float] = deque(maxlen=config.window)

    # ── public API ────────────────────────────────────────────────────

    def next_difficulty(self) -> str:
        """Difficulty to use for the next ``env.reset``."""
        return self.state.current

    def record(self, *, success: bool, reward: float = 0.0) -> Dict[str, object]:
        """Record one rollout outcome and (maybe) update the tier.

        Args:
            success: Whether the rollout solved the task.
            reward: Cumulative episode reward; accepted for logging /
                interface stability but not used by the stepping logic.

        Returns:
            A snapshot of curriculum state useful for logging. Values mix
            strings (``current``, ``event``) and floats, hence the
            ``Dict[str, object]`` annotation (the old ``Dict[str, float]``
            hint was inaccurate).
        """
        self._window.append(1.0 if success else 0.0)
        self.state.rollouts_at_current += 1
        self.state.successes_at_current += int(bool(success))
        snap = self._maybe_step()
        self.state.history.append(snap)
        return snap

    def to_dict(self) -> Dict:
        """Full serialisable view: config + state + current rolling success."""
        return {
            "config": asdict(self.config),
            "state": asdict(self.state),
            "rolling_success": self.rolling_success(),
        }

    def rolling_success(self) -> float:
        """Mean success over the current window; 0.0 when the window is empty."""
        if not self._window:
            return 0.0
        return float(sum(self._window) / len(self._window))

    def save(self, path: Path) -> None:
        """Write ``to_dict()`` as pretty-printed JSON to *path*.

        Encoding is pinned to UTF-8: ``write_text`` otherwise uses the
        locale's preferred encoding, which is not portable across machines.
        """
        Path(path).write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")

    # ── internals ─────────────────────────────────────────────────────

    def _maybe_step(self) -> Dict[str, object]:
        """Promote/demote based on the rolling window; return a log snapshot."""
        cfg = self.config
        rs = self.rolling_success()
        snap: Dict[str, object] = {
            "current": self.state.current,
            "rolling_success": rs,
            "rollouts_at_current": float(self.state.rollouts_at_current),
        }

        # Gate 1: require a minimum dwell time at this tier to avoid flapping.
        if self.state.rollouts_at_current < cfg.min_rollouts_per_tier:
            return snap
        # Gate 2: the window itself must hold enough samples — it is cleared
        # on every tier change, so dwell time alone is not sufficient.
        if len(self._window) < min(cfg.window, cfg.min_rollouts_per_tier):
            return snap

        idx = TIERS.index(self.state.current)

        if rs >= cfg.promote_threshold and idx < len(TIERS) - 1:
            target_idx = idx + 1 if cfg.one_step_at_a_time else len(TIERS) - 1
            self._move_to(TIERS[target_idx], promotion=True)
            snap["event"] = "promote"
            snap["current"] = self.state.current
            return snap

        if rs <= cfg.demote_threshold and idx > 0:
            target_idx = idx - 1 if cfg.one_step_at_a_time else 0
            self._move_to(TIERS[target_idx], promotion=False)
            snap["event"] = "demote"
            snap["current"] = self.state.current
            return snap

        return snap

    def _move_to(self, tier: str, *, promotion: bool) -> None:
        """Switch tier and reset all per-tier statistics, window included."""
        self.state.current = tier
        self.state.rollouts_at_current = 0
        self.state.successes_at_current = 0
        self._window.clear()
        if promotion:
            self.state.promotions += 1
        else:
            self.state.demotions += 1


# Explicit public API of this module.
__all__ = [
    "TIERS",
    "CurriculumConfig",
    "CurriculumManager",
    "CurriculumState",
]