"""Adaptive curriculum for CERNenv training.
The hackathon participant guide (§6, §14) and the FAQs (Q14, Q35) all push
the same idea: don't start RL on the hardest distribution because if the
probability of a successful rollout is ≈ 0, the policy never sees positive
reward and learning stalls. Conversely, if the env is too easy, the
gradient signal collapses and training plateaus (RLVE, FAQ Q46).
This module exposes a tiny ``CurriculumManager`` that:
* tracks rolling success rate across rollouts,
* promotes the agent up the difficulty ladder when the policy is clearly
succeeding (rolling success ≥ ``promote_threshold``),
* demotes the agent when the policy is repeatedly failing (rolling
success ≤ ``demote_threshold``) so it can recover instead of burning
compute on impossible tasks,
* exposes a ``next_difficulty()`` hook that the training loop calls
whenever it resets the environment.
The state is intentionally minimal so it can be serialised into a JSON
checkpoint or logged into ``evidence/training_log.csv``.
"""
from __future__ import annotations
import json
from collections import deque
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Deque, Dict, List, Optional
# Difficulty ladder, ordered easiest → hardest. A ``None`` difficulty means
# "let the sampler choose at random" and is deliberately kept off the ladder.
TIERS: List[str] = [
    "easy",
    "medium",
    "hard",
]
@dataclass
class CurriculumConfig:
    """Knobs for the adaptive curriculum.

    All thresholds operate on the rolling success rate computed over the
    last ``window`` rollouts (see ``CurriculumManager.rolling_success``).
    """

    # Tier used for the first rollouts; must be a member of ``TIERS``
    # (validated in ``CurriculumManager.__init__``).
    start_difficulty: str = "easy"
    # Length of the rolling success window (``maxlen`` of the deque).
    window: int = 24
    # Promote to the next tier when rolling success >= this value.
    promote_threshold: float = 0.55
    # Demote to the previous tier when rolling success <= this value.
    demote_threshold: float = 0.10
    # Minimum rollouts at the current tier before any promotion (or
    # demotion — ``_maybe_step`` gates both) can fire. Prevents flapping
    # when the rolling window is still short after a tier change.
    min_rollouts_per_tier: int = 12
    # If True, advance one tier at a time; otherwise jump to the highest
    # tier the policy clears.
    one_step_at_a_time: bool = True
@dataclass
class CurriculumState:
    """Serialisable curriculum state for logging / resume.

    Every field is a plain str/int/list so ``dataclasses.asdict`` output
    is directly JSON-serialisable (see ``CurriculumManager.to_dict``).
    """

    # Tier currently in use; one of ``TIERS``.
    current: str = "easy"
    # One snapshot dict per recorded rollout (values include strings such
    # as ``current``/``event`` as well as floats). Grows without bound;
    # trim or rotate externally for very long runs.
    history: List[Dict[str, object]] = field(default_factory=list)
    # Rollouts / successes observed since the last tier change.
    rollouts_at_current: int = 0
    successes_at_current: int = 0
    # Lifetime counts of tier transitions in each direction.
    promotions: int = 0
    demotions: int = 0
class CurriculumManager:
    """Rolling-window adaptive curriculum.

    Usage::

        cm = CurriculumManager(CurriculumConfig())
        for ep in range(N):
            diff = cm.next_difficulty()
            obs = env.reset(seed=ep, difficulty=diff)
            ...  # rollout
            cm.record(success=ep_success, reward=cumulative_reward)

    The manager is single-threaded; if you parallelise rollouts wrap
    ``record`` calls in a lock at the caller.
    """

    def __init__(self, config: Optional[CurriculumConfig] = None) -> None:
        """Create a manager.

        Args:
            config: Curriculum knobs. ``None`` (the default) builds a fresh
                ``CurriculumConfig()`` per instance. (The previous signature
                used ``config=CurriculumConfig()`` — a mutable default
                evaluated once and aliased by every manager constructed
                without an explicit config.)

        Raises:
            ValueError: if ``config.start_difficulty`` is not in ``TIERS``.
        """
        if config is None:
            config = CurriculumConfig()
        if config.start_difficulty not in TIERS:
            raise ValueError(
                f"start_difficulty must be one of {TIERS}, got {config.start_difficulty!r}"
            )
        self.config = config
        self.state = CurriculumState(current=config.start_difficulty)
        # Rolling window of 1.0 (success) / 0.0 (failure) outcomes; the
        # maxlen makes stale rollouts fall out automatically.
        self._window: Deque[float] = deque(maxlen=config.window)

    # ── public API ────────────────────────────────────────────────────

    def next_difficulty(self) -> str:
        """Difficulty to use for the next ``env.reset``."""
        return self.state.current

    def record(self, *, success: bool, reward: float = 0.0) -> Dict[str, object]:
        """Record one rollout outcome and (maybe) update the tier.

        Args:
            success: Whether the rollout solved the task.
            reward: Cumulative episode reward. Stored in the snapshot for
                logging only — it does not influence tier decisions.

        Returns:
            A snapshot of curriculum state useful for logging. Keys:
            ``current``, ``rolling_success``, ``rollouts_at_current``,
            ``reward`` and, on a tier change, ``event``.
        """
        self._window.append(1.0 if success else 0.0)
        self.state.rollouts_at_current += 1
        self.state.successes_at_current += int(bool(success))
        snap = self._maybe_step()
        # ``reward`` used to be accepted but silently dropped; keep it in
        # the snapshot so evidence logs can correlate reward and tier.
        snap["reward"] = float(reward)
        self.state.history.append(snap)
        return snap

    def to_dict(self) -> Dict:
        """JSON-serialisable dump of config + state for checkpointing."""
        return {
            "config": asdict(self.config),
            "state": asdict(self.state),
            "rolling_success": self.rolling_success(),
        }

    def rolling_success(self) -> float:
        """Mean of the rolling window, or 0.0 before any rollout."""
        if not self._window:
            return 0.0
        return float(sum(self._window) / len(self._window))

    def save(self, path: Path) -> None:
        """Write ``to_dict()`` as pretty-printed JSON to ``path``."""
        # Explicit encoding: ``write_text``'s default is platform-dependent.
        Path(path).write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")

    # ── internals ─────────────────────────────────────────────────────

    def _maybe_step(self) -> Dict[str, object]:
        """Apply the promote/demote rules; return a logging snapshot.

        Annotated ``Dict[str, object]`` (was ``Dict[str, float]``) because
        ``current`` and ``event`` hold strings.
        """
        cfg = self.config
        rs = self.rolling_success()
        snap: Dict[str, object] = {
            "current": self.state.current,
            "rolling_success": rs,
            "rollouts_at_current": float(self.state.rollouts_at_current),
        }
        # Gate BOTH promotion and demotion on a minimum sample size: the
        # window is cleared on every tier change, so early readings are noisy.
        if self.state.rollouts_at_current < cfg.min_rollouts_per_tier:
            return snap
        if len(self._window) < min(cfg.window, cfg.min_rollouts_per_tier):
            return snap
        idx = TIERS.index(self.state.current)
        if rs >= cfg.promote_threshold and idx < len(TIERS) - 1:
            target_idx = idx + 1 if cfg.one_step_at_a_time else len(TIERS) - 1
            self._move_to(TIERS[target_idx], promotion=True)
            snap["event"] = "promote"
            snap["current"] = self.state.current
            return snap
        if rs <= cfg.demote_threshold and idx > 0:
            target_idx = idx - 1 if cfg.one_step_at_a_time else 0
            self._move_to(TIERS[target_idx], promotion=False)
            snap["event"] = "demote"
            snap["current"] = self.state.current
            return snap
        return snap

    def _move_to(self, tier: str, *, promotion: bool) -> None:
        """Switch tier and reset per-tier counters plus the rolling window."""
        self.state.current = tier
        self.state.rollouts_at_current = 0
        self.state.successes_at_current = 0
        # Fresh window so the new tier's stats aren't polluted by the old tier.
        self._window.clear()
        if promotion:
            self.state.promotions += 1
        else:
            self.state.demotions += 1
# Public API surface of this module.
__all__ = ["TIERS", "CurriculumConfig", "CurriculumManager", "CurriculumState"]
|