File size: 1,720 Bytes
df97e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
Curriculum scheduler for staged training.

Stage 1 (0-30%)  : Easy only
Stage 2 (30-70%) : Easy + Medium (50/50)
Stage 3 (70-100%): All 3 tasks (20/40/40 weights)
"""

from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Tuple

# Identifiers of the three curriculum tasks.
TASK_EASY = "district_backlog_easy"
TASK_MEDIUM = "mixed_urgency_medium"
TASK_HARD = "cross_department_hard"

# Every task id, listed easy -> medium -> hard.
ALL_TASKS = [
    TASK_EASY,
    TASK_MEDIUM,
    TASK_HARD,
]


@dataclass
class CurriculumConfig:
    """
    Tunable settings for the staged curriculum.

    The two ``*_end_frac`` fields are fractions of overall training progress
    marking where stage 1 and stage 2 end; ``stage3_weights`` holds the
    sampling weights applied in the final stage.
    """

    # Progress fraction at which stage 1 ends.
    stage1_end_frac: float = 0.3
    # Progress fraction at which stage 2 ends (stage 3 covers the remainder).
    stage2_end_frac: float = 0.7
    # Relative sampling weights used during stage 3.
    stage3_weights: Tuple[float, ...] = (0.2, 0.4, 0.4)


class CurriculumScheduler:
    """
    Selects task_id for next training episode based on training progress.

    Progress is ``current_timestep / total_timesteps`` and is compared against
    the stage boundaries held in the config:

    * stage 1 (progress < ``stage1_end_frac``): always ``TASK_EASY``;
    * stage 2 (progress < ``stage2_end_frac``): ``TASK_EASY`` or
      ``TASK_MEDIUM``, chosen uniformly;
    * stage 3 (otherwise): any entry of ``ALL_TASKS``, drawn with
      ``stage3_weights``.

    All random draws come from a private ``random.Random`` seeded with
    ``rng_seed``.
    """

    def __init__(
        self,
        total_timesteps: int,
        config: CurriculumConfig | None = None,
        rng_seed: int = 0,
    ):
        """
        Args:
            total_timesteps: Length of the training run; progress is measured
                relative to it. Values below 1 are treated as 1.
            config: Stage boundaries and stage-3 weights. A default
                ``CurriculumConfig`` is used when omitted.
            rng_seed: Seed for the scheduler's private random generator.

        Raises:
            ValueError: If ``config.stage3_weights`` does not contain exactly
                one weight per entry of ``ALL_TASKS``.
        """
        self.total_timesteps = total_timesteps
        self.cfg = config if config is not None else CurriculumConfig()

        # Fail fast: random.choices would reject a mismatched weight list too,
        # but only once stage 3 is reached, i.e. late in the training run.
        n_weights = len(self.cfg.stage3_weights)
        if n_weights != len(ALL_TASKS):
            raise ValueError(
                f"stage3_weights must provide one weight per task "
                f"({len(ALL_TASKS)}), got {n_weights}"
            )

        self._rng = random.Random(rng_seed)

    def _progress(self, current_timestep: int) -> float:
        """Return training progress as a fraction of ``total_timesteps``."""
        # max(..., 1) guards against division by zero for non-positive totals.
        return current_timestep / max(self.total_timesteps, 1)

    def sample_task(self, current_timestep: int) -> str:
        """
        Return the task id to use for the episode starting at ``current_timestep``.

        The stage is taken from ``current_stage`` so that both methods always
        agree on where the stage boundaries lie. Stage 1 consumes no random
        numbers; stages 2 and 3 each make one draw from the private generator.
        """
        stage = self.current_stage(current_timestep)
        if stage == 1:
            return TASK_EASY
        if stage == 2:
            return self._rng.choice([TASK_EASY, TASK_MEDIUM])
        weights = list(self.cfg.stage3_weights)
        return self._rng.choices(ALL_TASKS, weights=weights, k=1)[0]

    def current_stage(self, current_timestep: int) -> int:
        """
        Return the curriculum stage (1, 2 or 3) active at ``current_timestep``.

        A boundary value belongs to the later stage: progress exactly equal to
        ``stage1_end_frac`` is stage 2, and exactly ``stage2_end_frac`` is
        stage 3.
        """
        progress = self._progress(current_timestep)
        if progress < self.cfg.stage1_end_frac:
            return 1
        if progress < self.cfg.stage2_end_frac:
            return 2
        return 3