"""Adaptive curriculum for CERNenv training.

The hackathon participant guide (§6, §14) and the FAQs (Q14, Q35) all push
the same idea: don't start RL on the hardest distribution, because if the
probability of a successful rollout is ≈ 0 the policy never sees positive
reward and learning stalls. Conversely, if the env is too easy, the gradient
signal collapses and training plateaus (RLVE, FAQ Q46).

This module exposes a tiny ``CurriculumManager`` that:

* tracks the rolling success rate across rollouts,
* promotes the agent up the difficulty ladder when the policy is clearly
  succeeding (rolling success ≥ ``promote_threshold``),
* demotes the agent when the policy is repeatedly failing (rolling
  success ≤ ``demote_threshold``) so it can recover instead of burning
  compute on impossible tasks,
* exposes a ``next_difficulty()`` hook that the training loop calls
  whenever it resets the environment.

The state is intentionally minimal so it can be serialised into a JSON
checkpoint or logged into ``evidence/training_log.csv``.
"""

from __future__ import annotations

import json
from collections import deque
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Deque, Dict, List, Optional

# Difficulty tiers, in order of increasing challenge. ``None`` is treated
# as "let the sampler choose at random" and is not part of the ladder.
TIERS: List[str] = ["easy", "medium", "hard"]


@dataclass
class CurriculumConfig:
    """Knobs for the adaptive curriculum."""

    # Tier the manager starts on; must be one of ``TIERS``.
    start_difficulty: str = "easy"
    # Number of most recent rollouts averaged into the rolling success rate.
    # Must be >= 1.
    window: int = 24
    # Rolling success at or above this triggers a promotion.
    promote_threshold: float = 0.55
    # Rolling success at or below this triggers a demotion.
    # Must be strictly below ``promote_threshold``.
    demote_threshold: float = 0.10
    # Minimum rollouts at the current tier before any promotion OR demotion
    # can fire. Prevents flapping when the rolling window is still short.
    min_rollouts_per_tier: int = 12
    # If True, move one tier at a time. If False, a promotion jumps straight
    # to the hardest tier and a demotion straight to the easiest tier.
    one_step_at_a_time: bool = True


@dataclass
class CurriculumState:
    """Serialisable curriculum state for logging / resume."""

    # Tier currently handed out by ``next_difficulty()``.
    current: str = "easy"
    # One snapshot dict per ``record()`` call; grows without bound over a run.
    # Values are mixed: tier names / event labels (str) and metrics (float).
    history: List[Dict[str, Any]] = field(default_factory=list)
    # Counters reset to 0 whenever the tier changes.
    rollouts_at_current: int = 0
    successes_at_current: int = 0
    # Lifetime tier-change counters.
    promotions: int = 0
    demotions: int = 0


class CurriculumManager:
    """Rolling-window adaptive curriculum.

    Usage::

        cm = CurriculumManager(CurriculumConfig())
        for ep in range(N):
            diff = cm.next_difficulty()
            obs = env.reset(seed=ep, difficulty=diff)
            ...  # rollout
            cm.record(success=ep_success, reward=cumulative_reward)

    The manager is single-threaded; if you parallelise rollouts wrap
    ``record`` calls in a lock at the caller.
    """

    def __init__(self, config: Optional[CurriculumConfig] = None) -> None:
        """Create a manager positioned at ``config.start_difficulty``.

        Args:
            config: Curriculum knobs. ``None`` (the default) builds a fresh
                ``CurriculumConfig()`` for this manager, so separate managers
                never share one mutable config instance.

        Raises:
            ValueError: if ``start_difficulty`` is not in ``TIERS``, if
                ``window`` < 1, or if ``demote_threshold >= promote_threshold``.
        """
        if config is None:
            config = CurriculumConfig()
        if config.start_difficulty not in TIERS:
            raise ValueError(
                f"start_difficulty must be one of {TIERS}, got {config.start_difficulty!r}"
            )
        if config.window < 1:
            # A zero-length deque never holds a sample, so rolling success
            # would be stuck at 0.0 and the policy could never be promoted.
            raise ValueError(f"window must be >= 1, got {config.window!r}")
        if config.demote_threshold >= config.promote_threshold:
            # Without a neutral band between the thresholds, every eligible
            # rollout (away from the ends of the ladder) would move the tier.
            raise ValueError(
                "demote_threshold must be < promote_threshold, got "
                f"{config.demote_threshold!r} >= {config.promote_threshold!r}"
            )
        self.config = config
        self.state = CurriculumState(current=config.start_difficulty)
        # 1.0 for a successful rollout, 0.0 otherwise; oldest entries drop off.
        self._window: Deque[float] = deque(maxlen=config.window)

    # ── public API ────────────────────────────────────────────────────

    def next_difficulty(self) -> str:
        """Difficulty to use for the next ``env.reset``."""
        return self.state.current

    def record(self, *, success: bool, reward: float = 0.0) -> Dict[str, Any]:
        """Record one rollout outcome and (maybe) update the tier.

        Args:
            success: Whether the rollout counts as a success.
            reward: Accepted so callers can pass the episode return uniformly;
                it is not currently used by the tier logic.

        Returns:
            A snapshot dict, also appended to ``state.history``, with keys
            ``current`` (tier after any move), ``rolling_success`` and
            ``rollouts_at_current`` (both measured before any move), plus
            ``event`` (``"promote"`` / ``"demote"``) when the tier changed.
        """
        self._window.append(1.0 if success else 0.0)
        self.state.rollouts_at_current += 1
        self.state.successes_at_current += int(bool(success))
        snap = self._maybe_step()
        self.state.history.append(snap)
        return snap

    def to_dict(self) -> Dict[str, Any]:
        """Return config, state and current rolling success as plain dicts."""
        return {
            "config": asdict(self.config),
            "state": asdict(self.state),
            "rolling_success": self.rolling_success(),
        }

    def rolling_success(self) -> float:
        """Mean success over the current window (0.0 when it is empty)."""
        if not self._window:
            return 0.0
        return float(sum(self._window) / len(self._window))

    def save(self, path: Path) -> None:
        """Write ``to_dict()`` to *path* as indented JSON, overwriting it."""
        Path(path).write_text(json.dumps(self.to_dict(), indent=2), encoding="utf-8")

    # ── internals ─────────────────────────────────────────────────────

    def _maybe_step(self) -> Dict[str, Any]:
        """Promote or demote if the rolling window warrants it; return a snapshot."""
        cfg = self.config
        rs = self.rolling_success()
        snap: Dict[str, Any] = {
            "current": self.state.current,
            "rolling_success": rs,
            "rollouts_at_current": float(self.state.rollouts_at_current),
        }

        # Warm-up gate: don't judge a tier on too few rollouts. The window
        # check is a defensive extra guard; the window is cleared on every
        # move, so it normally fills in step with ``rollouts_at_current``.
        if self.state.rollouts_at_current < cfg.min_rollouts_per_tier:
            return snap
        if len(self._window) < min(cfg.window, cfg.min_rollouts_per_tier):
            return snap

        idx = TIERS.index(self.state.current)
        last = len(TIERS) - 1

        # Promotion is checked first; at the top tier it cannot fire, so the
        # demotion check still gets a chance.
        if rs >= cfg.promote_threshold and idx < last:
            target_idx = idx + 1 if cfg.one_step_at_a_time else last
            self._move_to(TIERS[target_idx], promotion=True)
            snap["event"] = "promote"
        elif rs <= cfg.demote_threshold and idx > 0:
            target_idx = idx - 1 if cfg.one_step_at_a_time else 0
            self._move_to(TIERS[target_idx], promotion=False)
            snap["event"] = "demote"

        snap["current"] = self.state.current
        return snap

    def _move_to(self, tier: str, *, promotion: bool) -> None:
        """Switch to *tier*, reset per-tier counters and window, bump a counter."""
        self.state.current = tier
        self.state.rollouts_at_current = 0
        self.state.successes_at_current = 0
        # Fresh window so the new tier is judged only on its own rollouts.
        self._window.clear()
        if promotion:
            self.state.promotions += 1
        else:
            self.state.demotions += 1


__all__ = [
    "TIERS",
    "CurriculumConfig",
    "CurriculumManager",
    "CurriculumState",
]