Spaces:
Sleeping
Sleeping
feat: enhance CyberSecurity_OWASP observation model with scenario prompt, improve GRPO batch configuration validation, and add scenario grouping for adaptive difficulty curriculum
632c145 | """Scenario grouping and adaptive curriculum helpers for GRPO training.""" | |
| from __future__ import annotations | |
| import random | |
| import threading | |
| from collections.abc import Iterable, Mapping, Sequence | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| class AdaptiveDifficultyCurriculum: | |
| min_level: int = 0 | |
| max_level: int = 3 | |
| current_level: int = 0 | |
| promote_after: int = 50 | |
| promote_threshold: float = 0.70 | |
| demote_threshold: float = 0.35 | |
| ema_alpha: float = 0.10 | |
| rng_seed: int = 0 | |
| counts: dict[int, int] = field(default_factory=dict) | |
| ema_success: dict[int, float] = field(default_factory=dict) | |
| def __post_init__(self) -> None: | |
| self.min_level = int(self.min_level) | |
| self.max_level = int(self.max_level) | |
| self.current_level = max(self.min_level, min(int(self.current_level), self.max_level)) | |
| self._rng = random.Random(int(self.rng_seed)) | |
| def sample_difficulty(self, available_difficulties: Iterable[int]) -> int: | |
| available = {int(item) for item in available_difficulties} | |
| if not available: | |
| raise ValueError("No cached difficulties are available for GRPO curriculum sampling.") | |
| candidates = [ | |
| max(self.min_level, self.current_level - 1), | |
| self.current_level, | |
| min(self.max_level, self.current_level + 1), | |
| ] | |
| weights = [0.20, 0.65, 0.15] | |
| weighted: dict[int, float] = {} | |
| for level, weight in zip(candidates, weights): | |
| if level in available: | |
| weighted[level] = weighted.get(level, 0.0) + weight | |
| if not weighted: | |
| nearest = min(available, key=lambda level: (abs(level - self.current_level), level)) | |
| return nearest | |
| levels = list(weighted) | |
| return int(self._rng.choices(levels, weights=[weighted[level] for level in levels], k=1)[0]) | |
| def update(self, difficulty: int, success: float | bool) -> dict[str, Any]: | |
| level = int(difficulty) | |
| value = max(0.0, min(1.0, float(success))) | |
| self.counts[level] = self.counts.get(level, 0) + 1 | |
| old = self.ema_success.get(level, 0.0) | |
| self.ema_success[level] = (1.0 - self.ema_alpha) * old + self.ema_alpha * value | |
| if level == self.current_level and self.counts[level] >= self.promote_after: | |
| if self.ema_success[level] >= self.promote_threshold: | |
| self.current_level = min(self.max_level, self.current_level + 1) | |
| elif self.ema_success[level] <= self.demote_threshold: | |
| self.current_level = max(self.min_level, self.current_level - 1) | |
| return self.snapshot() | |
| def snapshot(self) -> dict[str, Any]: | |
| return { | |
| "current_level": self.current_level, | |
| "counts": {str(key): value for key, value in sorted(self.counts.items())}, | |
| "ema_success": { | |
| str(key): value for key, value in sorted(self.ema_success.items()) | |
| }, | |
| "current_level_ema_success": self.ema_success.get(self.current_level, 0.0), | |
| } | |
| class ScenarioGroupRegistry: | |
| """Assign each GRPO group to exactly one cached scenario.""" | |
| def __init__( | |
| self, | |
| entries: Sequence[Mapping[str, Any]], | |
| *, | |
| split: str = "train", | |
| initial_difficulty: int = 0, | |
| rng_seed: int = 0, | |
| max_level: int | None = None, | |
| ) -> None: | |
| self.split = split | |
| self._rng = random.Random(int(rng_seed)) | |
| self._lock = threading.Lock() | |
| self._assignments: dict[int, dict[str, Any]] = {} | |
| self._completed_groups: set[int] = set() | |
| self._entries_by_difficulty: dict[int, list[dict[str, Any]]] = {} | |
| self._cursors: dict[int, int] = {} | |
| for entry in entries: | |
| if entry.get("validated") is not True or entry.get("split") != split: | |
| continue | |
| difficulty = int(entry.get("difficulty", 0)) | |
| self._entries_by_difficulty.setdefault(difficulty, []).append(dict(entry)) | |
| for difficulty, items in self._entries_by_difficulty.items(): | |
| items.sort(key=lambda item: (int(item.get("seed", 0)), str(item.get("scenario_hash", "")))) | |
| self._rng.shuffle(items) | |
| self._cursors[difficulty] = 0 | |
| if not self._entries_by_difficulty: | |
| raise ValueError(f"No validated cached scenarios are available for split={split!r}.") | |
| available = sorted(self._entries_by_difficulty) | |
| resolved_max = max_level if max_level is not None else max(available) | |
| self.curriculum = AdaptiveDifficultyCurriculum( | |
| min_level=min(available), | |
| max_level=int(resolved_max), | |
| current_level=int(initial_difficulty), | |
| rng_seed=int(rng_seed), | |
| ) | |
| def available_difficulties(self) -> list[int]: | |
| return sorted(self._entries_by_difficulty) | |
| def assignment_for( | |
| self, | |
| *, | |
| scenario_group_id: int, | |
| requested_seed: int | None = None, | |
| requested_difficulty: int | None = None, | |
| split: str | None = None, | |
| difficulty_policy: str = "adaptive", | |
| ) -> dict[str, Any]: | |
| group_id = int(scenario_group_id) | |
| with self._lock: | |
| if group_id in self._assignments: | |
| return dict(self._assignments[group_id]) | |
| if difficulty_policy == "fixed": | |
| difficulty = int( | |
| requested_difficulty | |
| if requested_difficulty is not None | |
| else self.curriculum.current_level | |
| ) | |
| entry = self._find_entry( | |
| seed=requested_seed, | |
| split=split or self.split, | |
| difficulty=difficulty, | |
| ) or self._next_entry(difficulty) | |
| else: | |
| difficulty = self.curriculum.sample_difficulty(self.available_difficulties) | |
| entry = self._next_entry(difficulty) | |
| assignment = self._assignment_from_entry(group_id, entry) | |
| self._assignments[group_id] = assignment | |
| return dict(assignment) | |
| def record_group_outcome(self, scenario_group_id: int, success_rate: float) -> dict[str, Any] | None: | |
| group_id = int(scenario_group_id) | |
| with self._lock: | |
| if group_id in self._completed_groups: | |
| return None | |
| self._completed_groups.add(group_id) | |
| assignment = self._assignments.get(group_id) | |
| if not assignment: | |
| return self.curriculum.snapshot() | |
| return self.curriculum.update( | |
| int(assignment["difficulty"]), | |
| max(0.0, min(1.0, float(success_rate))), | |
| ) | |
| def metrics( | |
| self, | |
| records: Sequence[Mapping[str, Any]], | |
| *, | |
| unique_trace_count: int, | |
| duplicate_trace_suppressed_count: int, | |
| ) -> dict[str, float]: | |
| scenario_hashes = { | |
| str(record.get("scenario_hash") or record.get("scenario_id_hash") or "") | |
| for record in records | |
| if record.get("scenario_hash") or record.get("scenario_id_hash") | |
| } | |
| seeds = { | |
| int(record.get("scenario/seed", record.get("seed", 0)) or 0) | |
| for record in records | |
| } | |
| total = max(1, len(records)) | |
| snapshot = self.curriculum.snapshot() | |
| return { | |
| "train/unique_trace_count": float(unique_trace_count), | |
| "train/duplicate_trace_suppressed_count": float(duplicate_trace_suppressed_count), | |
| "train/unique_trace_rate": float(unique_trace_count) / total, | |
| "train/unique_seed_count": float(len(seeds)), | |
| "train/unique_scenario_hash_count": float(len(scenario_hashes)), | |
| "train/curriculum_level": float(snapshot["current_level"]), | |
| "train/curriculum_ema_success": float(snapshot["current_level_ema_success"]), | |
| } | |
| def _find_entry( | |
| self, | |
| *, | |
| seed: int | None, | |
| split: str, | |
| difficulty: int, | |
| ) -> dict[str, Any] | None: | |
| if seed is None or split != self.split: | |
| return None | |
| for entry in self._entries_by_difficulty.get(int(difficulty), []): | |
| if int(entry.get("seed", -1)) == int(seed): | |
| return dict(entry) | |
| return None | |
| def _next_entry(self, difficulty: int) -> dict[str, Any]: | |
| level = int(difficulty) | |
| items = self._entries_by_difficulty.get(level) | |
| if not items: | |
| nearest = min( | |
| self.available_difficulties, | |
| key=lambda item: (abs(item - level), item), | |
| ) | |
| items = self._entries_by_difficulty[nearest] | |
| level = nearest | |
| cursor = self._cursors.get(level, 0) | |
| self._cursors[level] = cursor + 1 | |
| return dict(items[cursor % len(items)]) | |
| def _assignment_from_entry(self, group_id: int, entry: Mapping[str, Any]) -> dict[str, Any]: | |
| cache_key = entry.get("cache_key") if isinstance(entry.get("cache_key"), Mapping) else {} | |
| return { | |
| "scenario_group_id": int(group_id), | |
| "seed": int(entry.get("seed", 0)), | |
| "split": str(entry.get("split", self.split)), | |
| "difficulty": int(entry.get("difficulty", 0)), | |
| "scenario_hash": str(entry.get("scenario_hash", "")), | |
| "template_id": str(entry.get("template_id") or cache_key.get("app_family", "")), | |
| "bug_family": str(entry.get("bug_family") or cache_key.get("authz_bug_type", "")), | |
| } | |
| def build_scenario_group_rows( | |
| *, | |
| dataset_size: int, | |
| training_prompt: str, | |
| seed_start: int = 0, | |
| split: str = "train", | |
| difficulty: int = 0, | |
| difficulty_policy: str = "adaptive", | |
| ) -> list[dict[str, Any]]: | |
| return [ | |
| { | |
| "prompt": [{"role": "user", "content": training_prompt}], | |
| "scenario_group_id": int(seed_start) + index, | |
| "seed": int(seed_start) + index, | |
| "difficulty": int(difficulty), | |
| "split": split, | |
| "difficulty_policy": difficulty_policy, | |
| } | |
| for index in range(int(dataset_size)) | |
| ] | |