# Cyber_analyst-round1 / server/curriculum.py
# Author: Humanlearning
# feat: enhance scenario authoring and caching mechanisms, update action
# submission terminology, and improve reward configuration for the
# CyberSecurity_OWASP environment (commit be8eade)
"""Runtime curriculum controller for closed-loop scenario selection."""
from __future__ import annotations
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Any
try:
from ..config import ScenarioAuthoringSettings, load_scenario_authoring_config
from ..models import CyberSecurityOWASPState
except ImportError: # pragma: no cover
from config import ScenarioAuthoringSettings, load_scenario_authoring_config
from models import CyberSecurityOWASPState
# Ordered difficulty tier names, easiest (D0) to hardest (D3).
# NOTE(review): not referenced within this module — tier labels are read from
# settings.curriculum.difficulty_labels instead; confirm external use before
# removing.
DIFFICULTY_TIERS = ("D0", "D1", "D2", "D3")
# Closed set of weakness categories a scenario can target.  Order matters:
# CurriculumController breaks failure-count ties by preferring the earliest
# entry, and seeds map onto targets by index modulo this tuple's length.
WEAKNESS_TARGETS = (
    "same_role_cross_object",
    "cross_tenant_boundary",
    "public_route_overlock",
    "alternate_route_same_service",
    "visible_test_edge_case",
)
@dataclass
class CurriculumController:
    """Tracks episode outcomes and picks the next bounded weakness target.

    Closed-loop controller: callers record finished episodes with
    :meth:`record_episode` and request the next scenario profile with
    :meth:`select_profile`.  Once any failure has been recorded, targeting
    steers toward the weakness with the most failures; difficulty is
    promoted one tier once the recent reward mean clears a fixed threshold.
    """

    # Size of the sliding reward window used for difficulty promotion.
    window_size: int = 10
    # Most recent episode rewards (bounded; capacity synced in __post_init__).
    reward_trend: deque[float] = field(default_factory=lambda: deque(maxlen=10))
    # Per-target success/failure history (True == success).
    outcomes_by_target: dict[str, list[bool]] = field(default_factory=lambda: defaultdict(list))
    # Per-target failure counters used to bias target selection.
    failures_by_target: dict[str, int] = field(default_factory=lambda: defaultdict(int))
    # Total number of episodes recorded so far.
    episodes_seen: int = 0
    # Authoring settings supplying difficulty bucket count and labels.
    settings: ScenarioAuthoringSettings = field(default_factory=load_scenario_authoring_config)

    def __post_init__(self) -> None:
        # Bug fix: the default factory hard-coded maxlen=10, so constructing
        # the controller with a custom window_size silently left the reward
        # window capped at 10.  Rebuild the deque whenever its capacity
        # disagrees with window_size (default construction is unaffected).
        if self.reward_trend.maxlen != self.window_size:
            self.reward_trend = deque(self.reward_trend, maxlen=self.window_size)

    def select_profile(
        self,
        *,
        seed: int,
        split: str = "train",
        requested_difficulty: int = 0,
    ) -> dict[str, Any]:
        """Return the scenario profile to author next.

        Args:
            seed: Episode seed; maps deterministically onto a weakness
                target while no failures have been recorded.
            split: Dataset split (``"train"`` or ``"hidden_eval"``).
            requested_difficulty: Caller-requested difficulty, clamped to
                the configured bucket range.

        Returns:
            Profile dict with difficulty, tier label, target weakness,
            split, episode count, recent reward mean, and mastery snapshot.
        """
        difficulty = self._difficulty_for_split(split, requested_difficulty)
        target = self._target_for_seed(seed, split)
        if self.failures_by_target:
            # Closed-loop targeting: most recorded failures wins; ties break
            # toward the earliest entry in WEAKNESS_TARGETS (the negated
            # index makes max() prefer lower positions).
            target = max(
                WEAKNESS_TARGETS,
                key=lambda item: (self.failures_by_target.get(item, 0), -WEAKNESS_TARGETS.index(item)),
            )
        return {
            "difficulty": difficulty,
            "difficulty_tier": self._difficulty_label(difficulty),
            "target_weakness": target,
            "split": split,
            "episodes_seen": self.episodes_seen,
            "recent_reward_mean": self._recent_reward_mean(),
            "mastery": self.mastery_snapshot(),
        }

    def record_episode(self, state: CyberSecurityOWASPState) -> dict[str, Any]:
        """Fold a finished episode into the controller's statistics.

        Returns:
            The updated mastery snapshot (same shape as
            :meth:`mastery_snapshot`).
        """
        # Fall back to the first weakness target when the state carries none.
        target = state.target_weakness or "same_role_cross_object"
        success = bool(state.success)
        self.episodes_seen += 1
        self.outcomes_by_target[target].append(success)
        if not success:
            self.failures_by_target[target] += 1
        # Missing rewards count as 0.0 so the trend window stays aligned
        # with the episode count.
        self.reward_trend.append(float(state.last_reward or 0.0))
        return self.mastery_snapshot()

    def mastery_snapshot(self) -> dict[str, Any]:
        """Summarize per-target mastery plus global episode statistics."""
        target_mastery: dict[str, dict[str, Any]] = {}
        for target in WEAKNESS_TARGETS:
            outcomes = self.outcomes_by_target.get(target, [])
            target_mastery[target] = {
                "episodes": len(outcomes),
                # max(1, ...) avoids division by zero for unseen targets.
                "success_rate": sum(1 for item in outcomes if item) / max(1, len(outcomes)),
                "failures": self.failures_by_target.get(target, 0),
            }
        return {
            "episodes_seen": self.episodes_seen,
            "recent_reward_mean": self._recent_reward_mean(),
            "target_mastery": target_mastery,
        }

    def _difficulty_for_split(self, split: str, requested_difficulty: int) -> int:
        """Clamp the requested difficulty and apply split/progress rules."""
        max_difficulty = self.settings.curriculum.difficulty_bucket_count - 1
        difficulty = max(0, min(int(requested_difficulty), max_difficulty))
        if split == "hidden_eval":
            # Hidden eval runs at least tier 3 (or the configured maximum,
            # whichever is lower) and never below the requested difficulty.
            return max(min(3, max_difficulty), difficulty)
        if self.episodes_seen >= self.window_size and self._recent_reward_mean() > 10.0:
            # Promote one tier once the reward window is full and the recent
            # mean clears the 10.0 threshold.
            return min(difficulty + 1, max_difficulty)
        return difficulty

    def _target_for_seed(self, seed: int, split: str) -> str:
        """Deterministically map a seed to a weakness target.

        Hidden eval shifts the rotation by a fixed offset so its targets do
        not mirror the training rotation for the same seeds.
        """
        offset = 2 if split == "hidden_eval" else 0
        return WEAKNESS_TARGETS[(int(seed) + offset) % len(WEAKNESS_TARGETS)]

    def _recent_reward_mean(self) -> float:
        """Mean of the rewards currently in the sliding window (0.0 if empty)."""
        if not self.reward_trend:
            return 0.0
        return sum(self.reward_trend) / len(self.reward_trend)

    def _difficulty_label(self, difficulty: int) -> str:
        """Translate a numeric difficulty into its configured label, clamped to range."""
        labels = self.settings.curriculum.difficulty_labels
        return labels[min(max(0, difficulty), len(labels) - 1)]