# Cyber_analyst-round1 / training / grpo_curriculum.py
# Author: Humanlearning
# feat: enhance CyberSecurity_OWASP observation model with scenario prompt,
# improve GRPO batch configuration validation, and add scenario grouping for
# adaptive difficulty curriculum (commit 632c145)
"""Scenario grouping and adaptive curriculum helpers for GRPO training."""
from __future__ import annotations
import random
import threading
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any
@dataclass
class AdaptiveDifficultyCurriculum:
    """EMA-driven difficulty ladder for GRPO curriculum sampling.

    Each :meth:`update` folds a group's success rate into a per-level
    exponential moving average. Once a level has accumulated at least
    ``promote_after`` observations, the controller promotes
    (EMA >= ``promote_threshold``) or demotes (EMA <= ``demote_threshold``)
    ``current_level`` by one step, clamped to ``[min_level, max_level]``.
    """

    min_level: int = 0  # lowest difficulty the curriculum may select
    max_level: int = 3  # highest difficulty the curriculum may select
    current_level: int = 0  # level the sampler currently centers on
    promote_after: int = 50  # observations at a level before any level change
    promote_threshold: float = 0.70  # EMA success needed to step up
    demote_threshold: float = 0.35  # EMA success at/below which we step down
    ema_alpha: float = 0.10  # smoothing factor for the success EMA, in (0, 1]
    rng_seed: int = 0  # seed for deterministic difficulty sampling
    counts: dict[int, int] = field(default_factory=dict)  # observations per level
    ema_success: dict[int, float] = field(default_factory=dict)  # EMA success per level

    def __post_init__(self) -> None:
        """Normalize configuration and build the seeded RNG.

        Raises:
            ValueError: if ``min_level > max_level`` or ``ema_alpha`` lies
                outside ``(0, 1]`` — either would make the controller
                silently misbehave (empty candidate window / frozen or
                diverging EMA), so we fail fast instead.
        """
        self.min_level = int(self.min_level)
        self.max_level = int(self.max_level)
        if self.min_level > self.max_level:
            raise ValueError(
                f"min_level ({self.min_level}) must not exceed max_level ({self.max_level})."
            )
        if not 0.0 < float(self.ema_alpha) <= 1.0:
            raise ValueError(f"ema_alpha must lie in (0, 1], got {self.ema_alpha!r}.")
        # Clamp the starting level into the valid range.
        self.current_level = max(self.min_level, min(int(self.current_level), self.max_level))
        self._rng = random.Random(int(self.rng_seed))

    def sample_difficulty(self, available_difficulties: Iterable[int]) -> int:
        """Sample a difficulty near ``current_level`` from the available set.

        Candidate levels one step below / at / one step above the current
        level are weighted 0.20 / 0.65 / 0.15; when clamping collapses two
        candidates onto the same level their weights merge. If none of the
        candidates is available, the nearest available level is returned
        deterministically.

        Raises:
            ValueError: if ``available_difficulties`` is empty.
        """
        available = {int(item) for item in available_difficulties}
        if not available:
            raise ValueError("No cached difficulties are available for GRPO curriculum sampling.")
        candidates = [
            max(self.min_level, self.current_level - 1),
            self.current_level,
            min(self.max_level, self.current_level + 1),
        ]
        weights = [0.20, 0.65, 0.15]
        weighted: dict[int, float] = {}
        for level, weight in zip(candidates, weights):
            if level in available:
                # Accumulate so duplicate candidates (at the clamp
                # boundaries) keep their combined probability mass.
                weighted[level] = weighted.get(level, 0.0) + weight
        if not weighted:
            # Deterministic fallback: closest available level, ties broken
            # toward the lower level.
            nearest = min(available, key=lambda level: (abs(level - self.current_level), level))
            return nearest
        levels = list(weighted)
        return int(self._rng.choices(levels, weights=[weighted[level] for level in levels], k=1)[0])

    def update(self, difficulty: int, success: float | bool) -> dict[str, Any]:
        """Record one group outcome and possibly move ``current_level``.

        ``success`` is clamped to [0, 1]; booleans are accepted. Only
        outcomes observed at the current level can trigger a promotion or
        demotion. Returns :meth:`snapshot` taken after the update.
        """
        level = int(difficulty)
        value = max(0.0, min(1.0, float(success)))
        self.counts[level] = self.counts.get(level, 0) + 1
        old = self.ema_success.get(level, 0.0)
        self.ema_success[level] = (1.0 - self.ema_alpha) * old + self.ema_alpha * value
        if level == self.current_level and self.counts[level] >= self.promote_after:
            if self.ema_success[level] >= self.promote_threshold:
                self.current_level = min(self.max_level, self.current_level + 1)
            elif self.ema_success[level] <= self.demote_threshold:
                self.current_level = max(self.min_level, self.current_level - 1)
        return self.snapshot()

    def snapshot(self) -> dict[str, Any]:
        """Return a JSON-friendly view of the curriculum state (string keys)."""
        return {
            "current_level": self.current_level,
            "counts": {str(key): value for key, value in sorted(self.counts.items())},
            "ema_success": {
                str(key): value for key, value in sorted(self.ema_success.items())
            },
            "current_level_ema_success": self.ema_success.get(self.current_level, 0.0),
        }
class ScenarioGroupRegistry:
    """Assign each GRPO group to exactly one cached scenario.

    Validated cache entries are bucketed by difficulty at construction time.
    Groups then draw scenarios either round-robin at an adaptively sampled
    difficulty or at a fixed, caller-requested difficulty. All mutable state
    (assignments, cursors, completed-group set, curriculum) is guarded by a
    single lock so concurrent rollout workers can share one registry.
    """

    def __init__(
        self,
        entries: Sequence[Mapping[str, Any]],
        *,
        split: str = "train",
        initial_difficulty: int = 0,
        rng_seed: int = 0,
        max_level: int | None = None,
    ) -> None:
        """Index validated cache entries for *split* and seed the curriculum.

        Raises:
            ValueError: if no entry is validated for the requested split.
        """
        self.split = split
        self._rng = random.Random(int(rng_seed))
        # Guards all mutable state below against concurrent callers.
        self._lock = threading.Lock()
        self._assignments: dict[int, dict[str, Any]] = {}  # group_id -> assignment
        self._completed_groups: set[int] = set()  # groups whose outcome was recorded
        self._entries_by_difficulty: dict[int, list[dict[str, Any]]] = {}
        self._cursors: dict[int, int] = {}  # round-robin position per difficulty
        for entry in entries:
            # Only entries explicitly validated for this split participate.
            if entry.get("validated") is not True or entry.get("split") != split:
                continue
            difficulty = int(entry.get("difficulty", 0))
            self._entries_by_difficulty.setdefault(difficulty, []).append(dict(entry))
        for difficulty, items in self._entries_by_difficulty.items():
            # Sort before the seeded shuffle so the resulting order is
            # deterministic regardless of how the caller ordered `entries`.
            items.sort(key=lambda item: (int(item.get("seed", 0)), str(item.get("scenario_hash", ""))))
            self._rng.shuffle(items)
            self._cursors[difficulty] = 0
        if not self._entries_by_difficulty:
            raise ValueError(f"No validated cached scenarios are available for split={split!r}.")
        available = sorted(self._entries_by_difficulty)
        # max_level may exceed the cached levels; sampling falls back to the
        # nearest available level in that case.
        resolved_max = max_level if max_level is not None else max(available)
        self.curriculum = AdaptiveDifficultyCurriculum(
            min_level=min(available),
            max_level=int(resolved_max),
            current_level=int(initial_difficulty),
            rng_seed=int(rng_seed),
        )

    @property
    def available_difficulties(self) -> list[int]:
        """Sorted difficulty levels that have at least one cached entry."""
        return sorted(self._entries_by_difficulty)

    def assignment_for(
        self,
        *,
        scenario_group_id: int,
        requested_seed: int | None = None,
        requested_difficulty: int | None = None,
        split: str | None = None,
        difficulty_policy: str = "adaptive",
    ) -> dict[str, Any]:
        """Return the (cached or newly created) scenario assignment for a group.

        Repeated calls with the same group id return a copy of the same
        assignment. With ``difficulty_policy="fixed"`` an exact seed match is
        preferred, falling back to round-robin at the requested difficulty
        (or the curriculum's current level when none is requested); any other
        policy samples the difficulty from the adaptive curriculum.
        """
        group_id = int(scenario_group_id)
        with self._lock:
            if group_id in self._assignments:
                # Idempotent: a group keeps its first assignment forever.
                return dict(self._assignments[group_id])
            if difficulty_policy == "fixed":
                difficulty = int(
                    requested_difficulty
                    if requested_difficulty is not None
                    else self.curriculum.current_level
                )
                entry = self._find_entry(
                    seed=requested_seed,
                    split=split or self.split,
                    difficulty=difficulty,
                ) or self._next_entry(difficulty)
            else:
                difficulty = self.curriculum.sample_difficulty(self.available_difficulties)
                entry = self._next_entry(difficulty)
            assignment = self._assignment_from_entry(group_id, entry)
            self._assignments[group_id] = assignment
            return dict(assignment)

    def record_group_outcome(self, scenario_group_id: int, success_rate: float) -> dict[str, Any] | None:
        """Feed a group's success rate into the curriculum, at most once.

        Returns None when the group was already recorded, the current
        curriculum snapshot when no assignment exists for the group (nothing
        to attribute), and the post-update snapshot otherwise.
        """
        group_id = int(scenario_group_id)
        with self._lock:
            if group_id in self._completed_groups:
                return None
            self._completed_groups.add(group_id)
            assignment = self._assignments.get(group_id)
            if not assignment:
                return self.curriculum.snapshot()
            return self.curriculum.update(
                int(assignment["difficulty"]),
                # Clamp defensively; callers may pass out-of-range rates.
                max(0.0, min(1.0, float(success_rate))),
            )

    def metrics(
        self,
        records: Sequence[Mapping[str, Any]],
        *,
        unique_trace_count: int,
        duplicate_trace_suppressed_count: int,
    ) -> dict[str, float]:
        """Compute float-valued training diagnostics for a batch of records.

        NOTE(review): records lacking any seed key all map to seed 0 here,
        so they collapse into one bucket of ``unique_seed_count`` — confirm
        that is acceptable for the emitted metrics.
        """
        scenario_hashes = {
            str(record.get("scenario_hash") or record.get("scenario_id_hash") or "")
            for record in records
            if record.get("scenario_hash") or record.get("scenario_id_hash")
        }
        seeds = {
            int(record.get("scenario/seed", record.get("seed", 0)) or 0)
            for record in records
        }
        # Guard against division by zero when records is empty.
        total = max(1, len(records))
        snapshot = self.curriculum.snapshot()
        return {
            "train/unique_trace_count": float(unique_trace_count),
            "train/duplicate_trace_suppressed_count": float(duplicate_trace_suppressed_count),
            "train/unique_trace_rate": float(unique_trace_count) / total,
            "train/unique_seed_count": float(len(seeds)),
            "train/unique_scenario_hash_count": float(len(scenario_hashes)),
            "train/curriculum_level": float(snapshot["current_level"]),
            "train/curriculum_ema_success": float(snapshot["current_level_ema_success"]),
        }

    def _find_entry(
        self,
        *,
        seed: int | None,
        split: str,
        difficulty: int,
    ) -> dict[str, Any] | None:
        """Return a copy of the entry matching seed/split/difficulty, or None."""
        if seed is None or split != self.split:
            return None
        for entry in self._entries_by_difficulty.get(int(difficulty), []):
            if int(entry.get("seed", -1)) == int(seed):
                return dict(entry)
        return None

    def _next_entry(self, difficulty: int) -> dict[str, Any]:
        """Return the next entry round-robin at *difficulty*.

        Falls back to the nearest cached difficulty (ties toward the lower
        level) when the requested level has no entries. Caller must hold
        ``self._lock`` — the cursor increment is not otherwise synchronized.
        """
        level = int(difficulty)
        items = self._entries_by_difficulty.get(level)
        if not items:
            nearest = min(
                self.available_difficulties,
                key=lambda item: (abs(item - level), item),
            )
            items = self._entries_by_difficulty[nearest]
            level = nearest
        cursor = self._cursors.get(level, 0)
        self._cursors[level] = cursor + 1
        # Modulo wraps the cursor so the rotation repeats indefinitely.
        return dict(items[cursor % len(items)])

    def _assignment_from_entry(self, group_id: int, entry: Mapping[str, Any]) -> dict[str, Any]:
        """Project a cache entry into the flat assignment dict handed to callers."""
        # cache_key fields act as fallbacks for template/bug identifiers.
        cache_key = entry.get("cache_key") if isinstance(entry.get("cache_key"), Mapping) else {}
        return {
            "scenario_group_id": int(group_id),
            "seed": int(entry.get("seed", 0)),
            "split": str(entry.get("split", self.split)),
            "difficulty": int(entry.get("difficulty", 0)),
            "scenario_hash": str(entry.get("scenario_hash", "")),
            "template_id": str(entry.get("template_id") or cache_key.get("app_family", "")),
            "bug_family": str(entry.get("bug_family") or cache_key.get("authz_bug_type", "")),
        }
def build_scenario_group_rows(
    *,
    dataset_size: int,
    training_prompt: str,
    seed_start: int = 0,
    split: str = "train",
    difficulty: int = 0,
    difficulty_policy: str = "adaptive",
) -> list[dict[str, Any]]:
    """Materialize one dataset row per GRPO scenario group.

    Every row carries the same user prompt plus bookkeeping fields; the
    group id and seed both increase by one per row starting at
    ``seed_start``. Returns an empty list when ``dataset_size`` is zero.
    """
    base_seed = int(seed_start)
    row_count = int(dataset_size)
    rows: list[dict[str, Any]] = []
    for offset in range(row_count):
        # Build a fresh prompt list/dict per row so rows never share state.
        rows.append(
            {
                "prompt": [{"role": "user", "content": training_prompt}],
                "scenario_group_id": base_seed + offset,
                "seed": base_seed + offset,
                "difficulty": int(difficulty),
                "split": split,
                "difficulty_policy": difficulty_policy,
            }
        )
    return rows