File size: 4,262 Bytes
6abc8c5
 
 
 
 
 
 
 
 
be8eade
6abc8c5
 
be8eade
6abc8c5
 
 
be8eade
6abc8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8eade
6abc8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8eade
6abc8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8eade
 
6abc8c5
be8eade
6abc8c5
be8eade
6abc8c5
 
 
 
 
 
 
 
 
 
be8eade
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Runtime curriculum controller for closed-loop scenario selection."""

from __future__ import annotations

from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Any

try:
    from ..config import ScenarioAuthoringSettings, load_scenario_authoring_config
    from ..models import CyberSecurityOWASPState
except ImportError:  # pragma: no cover
    from config import ScenarioAuthoringSettings, load_scenario_authoring_config
    from models import CyberSecurityOWASPState


# Ordered difficulty tier names (presumably D0 = easiest .. D3 = hardest).
# NOTE(review): not referenced by the visible code in this file — tier labels
# are read from settings.curriculum.difficulty_labels instead; confirm that
# external callers still depend on this tuple before removing it.
DIFFICULTY_TIERS = ("D0", "D1", "D2", "D3")
# Bounded weakness categories an episode may target. Order matters:
# CurriculumController.select_profile breaks failure-count ties toward the
# earlier entry, and _target_for_seed rotates through this tuple by index.
WEAKNESS_TARGETS = (
    "same_role_cross_object",
    "cross_tenant_boundary",
    "public_route_overlock",
    "alternate_route_same_service",
    "visible_test_edge_case",
)


@dataclass
class CurriculumController:
    """Tracks episode outcomes and picks the next bounded weakness target.

    Maintains a rolling window of recent rewards plus per-target
    success/failure tallies, and uses them to choose the difficulty and
    weakness target for the next episode.
    """

    # Number of recent episodes considered for the rolling reward mean.
    window_size: int = 10
    # Rolling window of recent rewards; bounded to window_size in __post_init__.
    reward_trend: deque[float] = field(default_factory=deque)
    # Per-target episode outcomes (True == success).
    outcomes_by_target: dict[str, list[bool]] = field(default_factory=lambda: defaultdict(list))
    # Per-target count of failed episodes; drives target prioritization.
    failures_by_target: dict[str, int] = field(default_factory=lambda: defaultdict(int))
    episodes_seen: int = 0
    settings: ScenarioAuthoringSettings = field(default_factory=load_scenario_authoring_config)

    def __post_init__(self) -> None:
        # Bug fix: the deque previously had a hard-coded maxlen=10, so a
        # caller-supplied window_size != 10 was silently ignored by the
        # rolling window while still gating promotion in _difficulty_for_split.
        # Rebinding here keeps the two in sync (default behavior unchanged).
        self.reward_trend = deque(self.reward_trend, maxlen=self.window_size)

    def select_profile(
        self,
        *,
        seed: int,
        split: str = "train",
        requested_difficulty: int = 0,
    ) -> dict[str, Any]:
        """Return the profile (difficulty + weakness target) for the next episode.

        Args:
            seed: Deterministic seed used for round-robin target rotation.
            split: Dataset split; ``"hidden_eval"`` forces the top bounded tier.
            requested_difficulty: Caller-preferred difficulty, clamped to the
                configured bucket range.

        Returns:
            Dict with difficulty, tier label, target weakness, split, episode
            count, recent reward mean, and a full mastery snapshot.
        """
        difficulty = self._difficulty_for_split(split, requested_difficulty)
        target = self._target_for_seed(seed, split)
        if self.failures_by_target:
            # Once any failures are recorded, override the seed rotation and
            # steer toward the most-failed target; ties break toward the
            # earlier entry in WEAKNESS_TARGETS (hence the negated index).
            target = max(
                WEAKNESS_TARGETS,
                key=lambda item: (self.failures_by_target.get(item, 0), -WEAKNESS_TARGETS.index(item)),
            )
        return {
            "difficulty": difficulty,
            "difficulty_tier": self._difficulty_label(difficulty),
            "target_weakness": target,
            "split": split,
            "episodes_seen": self.episodes_seen,
            "recent_reward_mean": self._recent_reward_mean(),
            "mastery": self.mastery_snapshot(),
        }

    def record_episode(self, state: CyberSecurityOWASPState) -> dict[str, Any]:
        """Fold one finished episode into the controller's statistics.

        Falls back to "same_role_cross_object" when the state carries no
        target, and treats a missing reward as 0.0. Returns the updated
        mastery snapshot.
        """
        target = state.target_weakness or "same_role_cross_object"
        success = bool(state.success)
        self.episodes_seen += 1
        self.outcomes_by_target[target].append(success)
        if not success:
            self.failures_by_target[target] += 1
        self.reward_trend.append(float(state.last_reward or 0.0))
        return self.mastery_snapshot()

    def mastery_snapshot(self) -> dict[str, Any]:
        """Return aggregate stats: per-target episodes, success rate, failures."""
        target_mastery = {}
        for target in WEAKNESS_TARGETS:
            outcomes = self.outcomes_by_target.get(target, [])
            target_mastery[target] = {
                "episodes": len(outcomes),
                # max(1, ...) avoids division by zero for unseen targets.
                "success_rate": sum(outcomes) / max(1, len(outcomes)),
                "failures": self.failures_by_target.get(target, 0),
            }
        return {
            "episodes_seen": self.episodes_seen,
            "recent_reward_mean": self._recent_reward_mean(),
            "target_mastery": target_mastery,
        }

    def _difficulty_for_split(self, split: str, requested_difficulty: int) -> int:
        """Clamp the requested difficulty, then apply split/progress rules."""
        max_difficulty = self.settings.curriculum.difficulty_bucket_count - 1
        difficulty = max(0, min(int(requested_difficulty), max_difficulty))
        if split == "hidden_eval":
            # Hidden evaluation runs at least at tier 3 (bounded by config).
            return max(min(3, max_difficulty), difficulty)
        if self.episodes_seen >= self.window_size and self._recent_reward_mean() > 10.0:
            # Promote one tier once the agent sustains high recent rewards
            # over a full window of episodes.
            return min(difficulty + 1, max_difficulty)
        return difficulty

    def _target_for_seed(self, seed: int, split: str) -> str:
        """Rotate targets deterministically by seed; hidden_eval is offset by 2."""
        offset = 2 if split == "hidden_eval" else 0
        return WEAKNESS_TARGETS[(int(seed) + offset) % len(WEAKNESS_TARGETS)]

    def _recent_reward_mean(self) -> float:
        """Mean of the rolling reward window, or 0.0 when it is empty."""
        if not self.reward_trend:
            return 0.0
        return sum(self.reward_trend) / len(self.reward_trend)

    def _difficulty_label(self, difficulty: int) -> str:
        """Map a difficulty index to its configured label, clamped in range."""
        labels = self.settings.curriculum.difficulty_labels
        return labels[min(max(0, difficulty), len(labels) - 1)]