# cernenv / training / curriculum.py
# anugrahhu's picture
# Update CERNenv Space
# f28409b verified
"""Adaptive curriculum for CERNenv training.

The hackathon participant guide (§6, §14) and the FAQs (Q14, Q35) all push
the same idea: don't start RL on the hardest distribution, because if the
probability of a successful rollout is ≈ 0, the policy never sees positive
reward and learning stalls. Conversely, if the env is too easy, the
gradient signal collapses and training plateaus (RLVE, FAQ Q46).

This module exposes a tiny ``CurriculumManager`` that:

* tracks rolling success rate across rollouts,
* promotes the agent up the difficulty ladder when the policy is clearly
  succeeding (rolling success ≥ ``promote_threshold``),
* demotes the agent when the policy is repeatedly failing (rolling
  success ≤ ``demote_threshold``) so it can recover instead of burning
  compute on impossible tasks,
* exposes a ``next_difficulty()`` hook that the training loop calls
  whenever it resets the environment.

The state is intentionally minimal so it can be serialised into a JSON
checkpoint or logged into ``evidence/training_log.csv``.
"""
from __future__ import annotations

import json
from collections import deque
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Deque, Dict, List, Optional
# Difficulty tiers, in order of increasing challenge. The order is
# significant: ``CurriculumManager`` promotes and demotes by moving along
# this list by index. ``None`` is treated as "let the sampler choose at
# random" and is not part of the ladder.
TIERS: List[str] = ["easy", "medium", "hard"]
@dataclass
class CurriculumConfig:
    """Tunable settings for the adaptive curriculum.

    Attributes:
        start_difficulty: Tier the curriculum begins on; the manager
            rejects values that are not in ``TIERS``.
        window: Maximum number of most-recent rollouts kept in the rolling
            success window.
        promote_threshold: Rolling success rate at or above which the
            manager moves to a harder tier.
        demote_threshold: Rolling success rate at or below which the
            manager moves to an easier tier.
        min_rollouts_per_tier: Rollouts that must be recorded on the
            current tier before any tier change (promotion or demotion)
            is considered. Prevents flapping while the window is short.
        one_step_at_a_time: When True, a tier change moves exactly one
            rung. When False, a promotion jumps straight to the hardest
            tier and a demotion drops straight to the easiest one.
    """

    # Where the ladder starts.
    start_difficulty: str = "easy"

    # Rolling-window length, in rollouts.
    window: int = 24

    # Success-rate bounds that trigger a move up / down the ladder.
    promote_threshold: float = 0.55
    demote_threshold: float = 0.10

    # Warm-up period on each tier before the thresholds are consulted.
    min_rollouts_per_tier: int = 12

    # Single-rung moves (True) versus jumping to the ladder's ends (False).
    one_step_at_a_time: bool = True
@dataclass
class CurriculumState:
    """Serialisable curriculum state for logging / resume.

    Attributes:
        current: Name of the tier currently in use.
        history: One snapshot dict per recorded rollout, in order. The
            snapshots mix value types: numeric entries such as
            ``"rolling_success"`` alongside string entries such as
            ``"current"`` and, on a tier change, ``"event"``.
        rollouts_at_current: Rollouts recorded since the last tier change.
        successes_at_current: Successful rollouts since the last tier
            change.
        promotions: Total number of moves to a harder tier.
        demotions: Total number of moves to an easier tier.
    """

    current: str = "easy"
    # ``default_factory`` gives every instance its own list. Values are
    # typed ``Any`` because snapshots hold both floats and strings.
    history: List[Dict[str, Any]] = field(default_factory=list)
    rollouts_at_current: int = 0
    successes_at_current: int = 0
    promotions: int = 0
    demotions: int = 0
class CurriculumManager:
    """Rolling-window adaptive curriculum.

    Usage::

        cm = CurriculumManager(CurriculumConfig())
        for ep in range(N):
            diff = cm.next_difficulty()
            obs = env.reset(seed=ep, difficulty=diff)
            ...  # rollout
            cm.record(success=ep_success, reward=cumulative_reward)

    The manager is single-threaded; if you parallelise rollouts wrap
    ``record`` calls in a lock at the caller.
    """

    def __init__(self, config: Optional[CurriculumConfig] = None) -> None:
        """Create a manager positioned on ``config.start_difficulty``.

        Args:
            config: Curriculum settings. When omitted, a fresh
                ``CurriculumConfig()`` is created for this instance.

        Raises:
            ValueError: If ``config.start_difficulty`` is not in ``TIERS``.
        """
        # Build the default here rather than in the signature: a default
        # expression is evaluated once, so every manager created without
        # arguments would otherwise share (and could mutate) one config.
        if config is None:
            config = CurriculumConfig()
        if config.start_difficulty not in TIERS:
            raise ValueError(
                f"start_difficulty must be one of {TIERS}, got {config.start_difficulty!r}"
            )
        self.config = config
        self.state = CurriculumState(current=config.start_difficulty)
        # Bounded deque of 1.0 / 0.0 outcomes; old entries fall off the
        # left once ``config.window`` outcomes are stored.
        self._window: Deque[float] = deque(maxlen=config.window)

    # ── public API ────────────────────────────────────────────────────
    def next_difficulty(self) -> str:
        """Difficulty to use for the next ``env.reset``."""
        return self.state.current

    def record(self, *, success: bool, reward: float = 0.0) -> Dict[str, Any]:
        """Record one rollout outcome and (maybe) update the tier.

        Args:
            success: Whether the rollout succeeded.
            reward: Accepted for caller convenience; it is not currently
                used or stored by this method.

        Returns:
            A snapshot of curriculum state useful for logging. It always
            has ``"current"``, ``"rolling_success"`` and
            ``"rollouts_at_current"``; when the tier changed it also has
            ``"event"`` (``"promote"`` or ``"demote"``) and ``"current"``
            names the new tier. The same dict object is appended to
            ``state.history``, so mutating it also changes the history.
        """
        self._window.append(1.0 if success else 0.0)
        self.state.rollouts_at_current += 1
        self.state.successes_at_current += int(bool(success))
        snap = self._maybe_step()
        self.state.history.append(snap)
        return snap

    def to_dict(self) -> Dict:
        """Return config, state and rolling success rate as plain dicts."""
        return {
            "config": asdict(self.config),
            "state": asdict(self.state),
            "rolling_success": self.rolling_success(),
        }

    def rolling_success(self) -> float:
        """Mean success over the rolling window, or 0.0 when it is empty."""
        if not self._window:
            return 0.0
        return float(sum(self._window) / len(self._window))

    def save(self, path: Path) -> None:
        """Write ``to_dict()`` to ``path`` as indented JSON.

        ``path`` is wrapped in ``Path``, so a plain string also works.
        """
        payload = json.dumps(self.to_dict(), indent=2)
        # Explicit encoding so the file does not depend on the locale.
        Path(path).write_text(payload, encoding="utf-8")

    # ── internals ─────────────────────────────────────────────────────
    def _maybe_step(self) -> Dict[str, Any]:
        """Apply the promotion / demotion rules and return a snapshot.

        The snapshot's ``"rolling_success"`` and ``"rollouts_at_current"``
        are the values from before any tier change; only ``"current"`` is
        refreshed afterwards.
        """
        cfg = self.config
        rate = self.rolling_success()
        snap: Dict[str, Any] = {
            "current": self.state.current,
            "rolling_success": rate,
            "rollouts_at_current": float(self.state.rollouts_at_current),
        }

        # Warm-up: no tier change until enough rollouts have been seen on
        # this tier and the window holds enough outcomes.
        if self.state.rollouts_at_current < cfg.min_rollouts_per_tier:
            return snap
        if len(self._window) < min(cfg.window, cfg.min_rollouts_per_tier):
            return snap

        idx = TIERS.index(self.state.current)
        top = len(TIERS) - 1
        if rate >= cfg.promote_threshold and idx < top:
            # One rung up, or straight to the hardest tier.
            target = idx + 1 if cfg.one_step_at_a_time else top
            self._move_to(TIERS[target], promotion=True)
            snap["event"] = "promote"
        elif rate <= cfg.demote_threshold and idx > 0:
            # One rung down, or straight to the easiest tier.
            target = idx - 1 if cfg.one_step_at_a_time else 0
            self._move_to(TIERS[target], promotion=False)
            snap["event"] = "demote"
        else:
            return snap

        snap["current"] = self.state.current
        return snap

    def _move_to(self, tier: str, *, promotion: bool) -> None:
        """Switch to ``tier`` and reset the per-tier counters and window.

        Args:
            tier: Name of the tier to move to.
            promotion: True to count the move as a promotion, False to
                count it as a demotion.
        """
        self.state.current = tier
        self.state.rollouts_at_current = 0
        self.state.successes_at_current = 0
        # Outcomes from the previous tier must not influence the next one.
        self._window.clear()
        if promotion:
            self.state.promotions += 1
        else:
            self.state.demotions += 1
# Names exported by a star-import of this module.
__all__ = [
"TIERS",
"CurriculumConfig",
"CurriculumManager",
"CurriculumState",
]