OPENENV_RL_01 / rl /curriculum.py
Siddharaj Shirke
deploy: fresh snapshot to Hugging Face Space
3eae4cc
"""
Curriculum scheduler for staged training.
Stage 1 (0-30%) : Easy only
Stage 2 (30-70%) : Easy + Medium (50/50)
Stage 3 (70-100%): All 3 tasks (20/40/40 weights)
"""
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Tuple
TASK_EASY = "district_backlog_easy"
TASK_MEDIUM = "mixed_urgency_medium"
TASK_HARD = "cross_department_hard"
ALL_TASKS = [TASK_EASY, TASK_MEDIUM, TASK_HARD]
@dataclass
class CurriculumConfig:
stage1_end_frac: float = 0.30
stage2_end_frac: float = 0.70
stage3_weights: Tuple[float, ...] = (0.20, 0.40, 0.40)
class CurriculumScheduler:
"""
Selects task_id for next training episode based on training progress.
"""
def __init__(
self,
total_timesteps: int,
config: CurriculumConfig | None = None,
rng_seed: int = 0,
):
self.total_timesteps = total_timesteps
self.cfg = config or CurriculumConfig()
self._rng = random.Random(rng_seed)
def sample_task(self, current_timestep: int) -> str:
progress = current_timestep / max(self.total_timesteps, 1)
if progress < self.cfg.stage1_end_frac:
return TASK_EASY
elif progress < self.cfg.stage2_end_frac:
return self._rng.choice([TASK_EASY, TASK_MEDIUM])
else:
return self._rng.choices(ALL_TASKS, weights=list(self.cfg.stage3_weights), k=1)[0]
def current_stage(self, current_timestep: int) -> int:
progress = current_timestep / max(self.total_timesteps, 1)
if progress < self.cfg.stage1_end_frac:
return 1
elif progress < self.cfg.stage2_end_frac:
return 2
return 3