| """Adaptive curriculum for CERNenv training.
|
|
|
| The hackathon participant guide (§6, §14) and the FAQs (Q14, Q35) all push
|
| the same idea: don't start RL on the hardest distribution because if the
|
| probability of a successful rollout is ≈ 0, the policy never sees positive
|
| reward and learning stalls. Conversely, if the env is too easy, the
|
| gradient signal collapses and training plateaus (RLVE, FAQ Q46).
|
|
|
| This module exposes a tiny ``CurriculumManager`` that:
|
|
|
| * tracks rolling success rate across rollouts,
|
| * promotes the agent up the difficulty ladder when the policy is clearly
|
| succeeding (rolling success ≥ ``promote_threshold``),
|
| * demotes the agent when the policy is repeatedly failing (rolling
|
| success ≤ ``demote_threshold``) so it can recover instead of burning
|
| compute on impossible tasks,
|
| * exposes a ``next_difficulty()`` hook that the training loop calls
|
| whenever it resets the environment.
|
|
|
| The state is intentionally minimal so it can be serialised into a JSON
|
| checkpoint or logged into ``evidence/training_log.csv``.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import json
|
| from collections import deque
|
| from dataclasses import asdict, dataclass, field
|
| from pathlib import Path
|
| from typing import Deque, Dict, List, Optional
|
|
|
|
|
|
|
|
|
# Difficulty ladder, ordered easiest -> hardest.  The promote/demote index
# arithmetic in ``CurriculumManager._maybe_step`` relies on this ordering.
TIERS: List[str] = ["easy", "medium", "hard"]
|
|
|
|
|
@dataclass
class CurriculumConfig:
    """Knobs for the adaptive curriculum."""

    # Tier used for the very first rollouts; must be a member of ``TIERS``
    # (validated by ``CurriculumManager.__init__``).
    start_difficulty: str = "easy"
    # Number of most-recent rollouts used for the rolling success rate.
    window: int = 24
    # Promote to the next tier when rolling success >= this value.
    promote_threshold: float = 0.55
    # Demote to the previous tier when rolling success <= this value.
    demote_threshold: float = 0.10
    # Minimum rollouts on the current tier before any tier change, so a
    # short lucky/unlucky streak right after a change cannot immediately
    # trigger another one (counters reset on every tier change).
    min_rollouts_per_tier: int = 12
    # If True, move one tier at a time; if False, jump straight to the
    # hardest tier on promotion (or the easiest on demotion).
    one_step_at_a_time: bool = True
|
|
|
|
|
@dataclass
class CurriculumState:
    """Serialisable curriculum state for logging / resume."""

    # Name of the tier currently in use (one of ``TIERS``).
    current: str = "easy"
    # One snapshot dict per recorded rollout; values are mixed str/float
    # despite the element annotation (see ``CurriculumManager.record``).
    history: List[Dict[str, float]] = field(default_factory=list)
    # Rollouts recorded since the last tier change.
    rollouts_at_current: int = 0
    # Successful rollouts recorded since the last tier change.
    successes_at_current: int = 0
    # Lifetime totals of tier changes in each direction.
    promotions: int = 0
    demotions: int = 0
|
|
|
|
|
class CurriculumManager:
    """Rolling-window adaptive curriculum.

    Usage::

        cm = CurriculumManager(CurriculumConfig())
        for ep in range(N):
            diff = cm.next_difficulty()
            obs = env.reset(seed=ep, difficulty=diff)
            ...  # rollout
            cm.record(success=ep_success, reward=cumulative_reward)

    The manager is single-threaded; if you parallelise rollouts wrap
    ``record`` calls in a lock at the caller.
    """

    def __init__(self, config: Optional[CurriculumConfig] = None) -> None:
        """Initialise the manager.

        Args:
            config: Curriculum knobs; ``None`` (the default) builds a fresh
                ``CurriculumConfig()`` per instance.  The previous signature
                used ``config=CurriculumConfig()`` directly — a mutable
                default evaluated once at definition time, so every manager
                constructed without an explicit config shared one object.

        Raises:
            ValueError: If ``config.start_difficulty`` is not in ``TIERS``.
        """
        if config is None:
            config = CurriculumConfig()
        if config.start_difficulty not in TIERS:
            raise ValueError(
                f"start_difficulty must be one of {TIERS}, got {config.start_difficulty!r}"
            )
        self.config = config
        self.state = CurriculumState(current=config.start_difficulty)
        # Rolling window of outcomes: 1.0 = success, 0.0 = failure.
        self._window: Deque[float] = deque(maxlen=config.window)

    def next_difficulty(self) -> str:
        """Difficulty to use for the next ``env.reset``."""
        return self.state.current

    def record(self, *, success: bool, reward: float = 0.0) -> Dict[str, float]:
        """Record one rollout outcome and (maybe) update the tier.

        Args:
            success: Whether the rollout met the success criterion.
            reward: Cumulative episode reward.  Logged into the returned
                snapshot (it was previously accepted but silently dropped);
                it never influences promotion/demotion decisions.

        Returns:
            A snapshot of curriculum state useful for logging.
        """
        self._window.append(1.0 if success else 0.0)
        self.state.rollouts_at_current += 1
        self.state.successes_at_current += int(bool(success))
        snap = self._maybe_step()
        # Surface the reward so evidence/training_log.csv can include it.
        snap["reward"] = float(reward)
        self.state.history.append(snap)
        return snap

    def to_dict(self) -> Dict:
        """JSON-serialisable view of config + state for checkpoints/logs."""
        return {
            "config": asdict(self.config),
            "state": asdict(self.state),
            "rolling_success": self.rolling_success(),
        }

    def rolling_success(self) -> float:
        """Success rate over the rolling window (0.0 when the window is empty)."""
        if not self._window:
            return 0.0
        return float(sum(self._window) / len(self._window))

    def save(self, path: Path) -> None:
        """Write ``to_dict()`` as pretty-printed JSON to ``path``."""
        Path(path).write_text(json.dumps(self.to_dict(), indent=2))

    def _maybe_step(self) -> Dict[str, float]:
        """Promote/demote based on the rolling window; return a log snapshot.

        NOTE: snapshot values are mixed ``str``/``float`` despite the
        historical ``Dict[str, float]`` annotation (kept for compatibility).
        """
        cfg = self.config
        rs = self.rolling_success()
        snap: Dict[str, float] = {
            "current": self.state.current,
            "rolling_success": rs,
            "rollouts_at_current": float(self.state.rollouts_at_current),
        }

        # Require a minimum amount of evidence on this tier before moving.
        # The window is cleared on every tier change, so both guards bite
        # again immediately after a promotion/demotion.
        if self.state.rollouts_at_current < cfg.min_rollouts_per_tier:
            return snap
        if len(self._window) < min(cfg.window, cfg.min_rollouts_per_tier):
            return snap

        idx = TIERS.index(self.state.current)

        if rs >= cfg.promote_threshold and idx < len(TIERS) - 1:
            target_idx = idx + 1 if cfg.one_step_at_a_time else len(TIERS) - 1
            self._move_to(TIERS[target_idx], promotion=True)
            snap["event"] = "promote"
            snap["current"] = self.state.current
            return snap

        if rs <= cfg.demote_threshold and idx > 0:
            target_idx = idx - 1 if cfg.one_step_at_a_time else 0
            self._move_to(TIERS[target_idx], promotion=False)
            snap["event"] = "demote"
            snap["current"] = self.state.current
            return snap

        return snap

    def _move_to(self, tier: str, *, promotion: bool) -> None:
        """Switch tiers and reset per-tier counters and the rolling window."""
        self.state.current = tier
        self.state.rollouts_at_current = 0
        self.state.successes_at_current = 0
        self._window.clear()
        if promotion:
            self.state.promotions += 1
        else:
            self.state.demotions += 1
|
|
|
|
|
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    "TIERS",
    "CurriculumConfig",
    "CurriculumManager",
    "CurriculumState",
]
|
|
|